forked from mindspore-Ecosystem/mindspore
!8802 Boilerplate code for IR Tree optimizer
From: @nsyca Reviewed-by: Signed-off-by:
This commit is contained in:
commit
bd8522aff7
|
@ -568,8 +568,8 @@ std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
|
|||
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
|
||||
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms) {
|
||||
auto vocab = std::make_shared<SentencePieceVocab>();
|
||||
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode(), vocab, col_names, vocab_size, character_coverage,
|
||||
model_type, params);
|
||||
auto ds = std::make_shared<BuildSentenceVocabNode>(IRNode()->DeepCopy(), vocab, col_names, vocab_size,
|
||||
character_coverage, model_type, params);
|
||||
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
Status rc = runtime_context->Init();
|
||||
|
@ -600,8 +600,8 @@ std::shared_ptr<Vocab> Dataset::BuildVocab(const std::vector<std::string> &colum
|
|||
const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
|
||||
const std::vector<std::string> &special_tokens, bool special_first) {
|
||||
auto vocab = std::make_shared<Vocab>();
|
||||
auto ds =
|
||||
std::make_shared<BuildVocabNode>(IRNode(), vocab, columns, freq_range, top_k, special_tokens, special_first);
|
||||
auto ds = std::make_shared<BuildVocabNode>(IRNode()->DeepCopy(), vocab, columns, freq_range, top_k, special_tokens,
|
||||
special_first);
|
||||
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
Status rc = runtime_context->Init();
|
||||
|
|
|
@ -190,13 +190,12 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
|
|||
return sampler;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// PreBuiltOperation
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler)
|
||||
: sp_(std::move(sampler)), sp_minddataset_(nullptr) {}
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) : sp_(std::move(sampler)) {}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler)
|
||||
: sp_(nullptr), sp_minddataset_(std::move(sampler)) {}
|
||||
: sp_minddataset_(std::move(sampler)) {}
|
||||
#endif
|
||||
|
||||
bool PreBuiltSamplerObj::ValidateParams() { return true; }
|
||||
|
@ -207,6 +206,13 @@ std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; }
|
|||
std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; }
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> PreBuiltSamplerObj::Copy() {
|
||||
#ifndef ENABLE_ANDROID
|
||||
if (sp_minddataset_ != nullptr) return std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
#endif
|
||||
return std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
|
||||
// runtime mindrecord sampler object
|
||||
|
|
|
@ -30,8 +30,6 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
TensorOperation::TensorOperation() {}
|
||||
|
||||
/* ####################################### Validator Functions ############################################ */
|
||||
Status ValidateVectorFillvalue(const std::string &transform_name, const std::vector<uint8_t> &fill_value) {
|
||||
if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
|
||||
|
@ -231,7 +229,7 @@ std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; }
|
|||
|
||||
// RandomApplyOperation
|
||||
RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob)
|
||||
: transforms_(transforms), prob_(prob) {}
|
||||
: TensorOperation(true), transforms_(transforms), prob_(prob) {}
|
||||
|
||||
Status RandomApplyOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomApply", transforms_));
|
||||
|
@ -248,7 +246,7 @@ std::shared_ptr<TensorOp> RandomApplyOperation::Build() {
|
|||
|
||||
// RandomChoiceOperation
|
||||
RandomChoiceOperation::RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms)
|
||||
: transforms_(transforms) {}
|
||||
: TensorOperation(true), transforms_(transforms) {}
|
||||
|
||||
Status RandomChoiceOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateVectorTransforms("RandomChoice", transforms_));
|
||||
|
|
|
@ -734,7 +734,9 @@ RandomAffineOperation::RandomAffineOperation(const std::vector<float_t> °rees
|
|||
scale_range_(scale_range),
|
||||
shear_ranges_(shear_ranges),
|
||||
interpolation_(interpolation),
|
||||
fill_value_(fill_value) {}
|
||||
fill_value_(fill_value) {
|
||||
random_op_ = true;
|
||||
}
|
||||
|
||||
Status RandomAffineOperation::ValidateParams() {
|
||||
// Degrees
|
||||
|
@ -867,7 +869,7 @@ std::shared_ptr<TensorOp> RandomAffineOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomColorOperation.
|
||||
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) {}
|
||||
RandomColorOperation::RandomColorOperation(float t_lb, float t_ub) : t_lb_(t_lb), t_ub_(t_ub) { random_op_ = true; }
|
||||
|
||||
Status RandomColorOperation::ValidateParams() {
|
||||
// Do some input validation.
|
||||
|
@ -891,7 +893,9 @@ Status RandomColorOperation::ValidateParams() {
|
|||
// RandomColorAdjustOperation.
|
||||
RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
|
||||
std::vector<float> saturation, std::vector<float> hue)
|
||||
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
|
||||
: brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {
|
||||
random_op_ = true;
|
||||
}
|
||||
|
||||
Status RandomColorAdjustOperation::ValidateParams() {
|
||||
// brightness
|
||||
|
@ -1012,11 +1016,14 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
|
|||
// RandomCropOperation
|
||||
RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
|
||||
std::vector<uint8_t> fill_value, BorderType padding_mode)
|
||||
: size_(size),
|
||||
: TensorOperation(true),
|
||||
size_(size),
|
||||
padding_(padding),
|
||||
pad_if_needed_(pad_if_needed),
|
||||
fill_value_(fill_value),
|
||||
padding_mode_(padding_mode) {}
|
||||
padding_mode_(padding_mode) {
|
||||
random_op_ = true;
|
||||
}
|
||||
|
||||
Status RandomCropOperation::ValidateParams() {
|
||||
// size
|
||||
|
@ -1083,7 +1090,12 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
|
|||
RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector<int32_t> size, std::vector<float> scale,
|
||||
std::vector<float> ratio,
|
||||
InterpolationMode interpolation, int32_t max_attempts)
|
||||
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
|
||||
: TensorOperation(true),
|
||||
size_(size),
|
||||
scale_(scale),
|
||||
ratio_(ratio),
|
||||
interpolation_(interpolation),
|
||||
max_attempts_(max_attempts) {}
|
||||
|
||||
Status RandomCropDecodeResizeOperation::ValidateParams() {
|
||||
// size
|
||||
|
@ -1176,7 +1188,8 @@ std::shared_ptr<TensorOp> RandomCropDecodeResizeOperation::Build() {
|
|||
RandomCropWithBBoxOperation::RandomCropWithBBoxOperation(std::vector<int32_t> size, std::vector<int32_t> padding,
|
||||
bool pad_if_needed, std::vector<uint8_t> fill_value,
|
||||
BorderType padding_mode)
|
||||
: size_(size),
|
||||
: TensorOperation(true),
|
||||
size_(size),
|
||||
padding_(padding),
|
||||
pad_if_needed_(pad_if_needed),
|
||||
fill_value_(fill_value),
|
||||
|
@ -1245,7 +1258,8 @@ std::shared_ptr<TensorOp> RandomCropWithBBoxOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomHorizontalFlipOperation
|
||||
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {}
|
||||
RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability)
|
||||
: TensorOperation(true), probability_(probability) {}
|
||||
|
||||
Status RandomHorizontalFlipOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlip", probability_));
|
||||
|
@ -1260,7 +1274,7 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipOperation::Build() {
|
|||
|
||||
// RandomHorizontalFlipWithBBoxOperation
|
||||
RandomHorizontalFlipWithBBoxOperation::RandomHorizontalFlipWithBBoxOperation(float probability)
|
||||
: probability_(probability) {}
|
||||
: TensorOperation(true), probability_(probability) {}
|
||||
|
||||
Status RandomHorizontalFlipWithBBoxOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateProbability("RandomHorizontalFlipWithBBox", probability_));
|
||||
|
@ -1275,7 +1289,8 @@ std::shared_ptr<TensorOp> RandomHorizontalFlipWithBBoxOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomPosterizeOperation
|
||||
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range) : bit_range_(bit_range) {}
|
||||
RandomPosterizeOperation::RandomPosterizeOperation(const std::vector<uint8_t> &bit_range)
|
||||
: TensorOperation(true), bit_range_(bit_range) {}
|
||||
|
||||
Status RandomPosterizeOperation::ValidateParams() {
|
||||
if (bit_range_.size() != 2) {
|
||||
|
@ -1309,7 +1324,7 @@ std::shared_ptr<TensorOp> RandomPosterizeOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomResizeOperation
|
||||
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : size_(size) {}
|
||||
RandomResizeOperation::RandomResizeOperation(std::vector<int32_t> size) : TensorOperation(true), size_(size) {}
|
||||
|
||||
Status RandomResizeOperation::ValidateParams() {
|
||||
// size
|
||||
|
@ -1343,7 +1358,8 @@ std::shared_ptr<TensorOp> RandomResizeOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomResizeWithBBoxOperation
|
||||
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size) : size_(size) {}
|
||||
RandomResizeWithBBoxOperation::RandomResizeWithBBoxOperation(std::vector<int32_t> size)
|
||||
: TensorOperation(true), size_(size) {}
|
||||
|
||||
Status RandomResizeWithBBoxOperation::ValidateParams() {
|
||||
// size
|
||||
|
@ -1380,7 +1396,12 @@ std::shared_ptr<TensorOp> RandomResizeWithBBoxOperation::Build() {
|
|||
RandomResizedCropOperation::RandomResizedCropOperation(std::vector<int32_t> size, std::vector<float> scale,
|
||||
std::vector<float> ratio, InterpolationMode interpolation,
|
||||
int32_t max_attempts)
|
||||
: size_(size), scale_(scale), ratio_(ratio), interpolation_(interpolation), max_attempts_(max_attempts) {}
|
||||
: TensorOperation(true),
|
||||
size_(size),
|
||||
scale_(scale),
|
||||
ratio_(ratio),
|
||||
interpolation_(interpolation),
|
||||
max_attempts_(max_attempts) {}
|
||||
|
||||
Status RandomResizedCropOperation::ValidateParams() {
|
||||
// size
|
||||
|
@ -1536,7 +1557,8 @@ std::shared_ptr<TensorOp> RandomResizedCropWithBBoxOperation::Build() {
|
|||
RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
|
||||
bool expand, std::vector<float> center,
|
||||
std::vector<uint8_t> fill_value)
|
||||
: degrees_(degrees),
|
||||
: TensorOperation(true),
|
||||
degrees_(degrees),
|
||||
interpolation_mode_(interpolation_mode),
|
||||
expand_(expand),
|
||||
center_(center),
|
||||
|
@ -1603,7 +1625,7 @@ std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
|
|||
// RandomSelectSubpolicyOperation.
|
||||
RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation(
|
||||
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy)
|
||||
: policy_(policy) {}
|
||||
: TensorOperation(true), policy_(policy) {}
|
||||
|
||||
Status RandomSelectSubpolicyOperation::ValidateParams() {
|
||||
if (policy_.empty()) {
|
||||
|
@ -1650,7 +1672,8 @@ std::shared_ptr<TensorOp> RandomSelectSubpolicyOperation::Build() {
|
|||
}
|
||||
|
||||
// Function to create RandomSharpness.
|
||||
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees) : degrees_(degrees) {}
|
||||
RandomSharpnessOperation::RandomSharpnessOperation(std::vector<float> degrees)
|
||||
: TensorOperation(true), degrees_(degrees) {}
|
||||
|
||||
Status RandomSharpnessOperation::ValidateParams() {
|
||||
if (degrees_.size() != 2 || degrees_[0] < 0 || degrees_[1] < 0) {
|
||||
|
@ -1674,7 +1697,8 @@ std::shared_ptr<TensorOp> RandomSharpnessOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomSolarizeOperation.
|
||||
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold) : threshold_(threshold) {}
|
||||
RandomSolarizeOperation::RandomSolarizeOperation(std::vector<uint8_t> threshold)
|
||||
: TensorOperation(true), threshold_(threshold) {}
|
||||
|
||||
Status RandomSolarizeOperation::ValidateParams() {
|
||||
if (threshold_.size() != 2) {
|
||||
|
@ -1705,7 +1729,8 @@ std::shared_ptr<TensorOp> RandomSolarizeOperation::Build() {
|
|||
}
|
||||
|
||||
// RandomVerticalFlipOperation
|
||||
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {}
|
||||
RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability)
|
||||
: TensorOperation(true), probability_(probability) {}
|
||||
|
||||
Status RandomVerticalFlipOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlip", probability_));
|
||||
|
@ -1720,7 +1745,7 @@ std::shared_ptr<TensorOp> RandomVerticalFlipOperation::Build() {
|
|||
|
||||
// RandomVerticalFlipWithBBoxOperation
|
||||
RandomVerticalFlipWithBBoxOperation::RandomVerticalFlipWithBBoxOperation(float probability)
|
||||
: probability_(probability) {}
|
||||
: TensorOperation(true), probability_(probability) {}
|
||||
|
||||
Status RandomVerticalFlipWithBBoxOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateProbability("RandomVerticalFlipWithBBox", probability_));
|
||||
|
|
|
@ -9,11 +9,13 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
|
|||
build_sentence_piece_vocab_node.cc
|
||||
build_vocab_node.cc
|
||||
concat_node.cc
|
||||
epoch_ctrl_node.cc
|
||||
filter_node.cc
|
||||
map_node.cc
|
||||
project_node.cc
|
||||
rename_node.cc
|
||||
repeat_node.cc
|
||||
root_node.cc
|
||||
shuffle_node.cc
|
||||
skip_node.cc
|
||||
sync_wait_node.cc
|
||||
|
|
|
@ -43,14 +43,29 @@ BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, boo
|
|||
batch_size_func_(batch_size_func),
|
||||
batch_map_func_(batch_map_func),
|
||||
pad_map_(pad_map) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
#endif
|
||||
|
||||
// constructor #2, called by C++ API
|
||||
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder)
|
||||
: batch_size_(batch_size), drop_remainder_(drop_remainder), pad_(false) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> BatchNode::Copy() {
|
||||
#ifdef ENABLE_PYTHON
|
||||
auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_, pad_, in_col_names_, out_col_names_,
|
||||
col_order_, batch_size_func_, batch_map_func_, pad_map_);
|
||||
#else
|
||||
auto node = std::make_shared<BatchNode>(nullptr, batch_size_, drop_remainder_);
|
||||
#endif
|
||||
return node;
|
||||
}
|
||||
|
||||
void BatchNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(batch_size:" + std::to_string(batch_size_) +
|
||||
" drop_remainder:" + (drop_remainder_ ? "true" : "false") + ")";
|
||||
}
|
||||
|
||||
Status BatchNode::ValidateParams() {
|
||||
|
|
|
@ -44,6 +44,18 @@ class BatchNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~BatchNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kBatchNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -41,7 +41,17 @@ BucketBatchByLengthNode::BucketBatchByLengthNode(
|
|||
pad_info_(pad_info),
|
||||
pad_to_bucket_boundary_(pad_to_bucket_boundary),
|
||||
drop_remainder_(drop_remainder) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> BucketBatchByLengthNode::Copy() {
|
||||
auto node = std::make_shared<BucketBatchByLengthNode>(nullptr, column_names_, bucket_boundaries_, bucket_batch_sizes_,
|
||||
element_length_function_, pad_info_, pad_to_bucket_boundary_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void BucketBatchByLengthNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(columns:" + PrintColumns(column_names_) + ",...)";
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
|
||||
|
|
|
@ -40,6 +40,18 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~BucketBatchByLengthNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kBucketBatchByLengthNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/build_sentence_piece_vocab_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -38,7 +39,18 @@ BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> chil
|
|||
character_coverage_(character_coverage),
|
||||
model_type_(model_type),
|
||||
params_(params) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> BuildSentenceVocabNode::Copy() {
|
||||
auto node = std::make_shared<BuildSentenceVocabNode>(nullptr, vocab_, col_names_, vocab_size_, character_coverage_,
|
||||
model_type_, params_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void BuildSentenceVocabNode::Print(std::ostream &out) const {
|
||||
out << Name() + "<vocab>," + "columns:" + PrintColumns(col_names_) + ",vocab_size:" + std::to_string(vocab_size_) +
|
||||
",...)";
|
||||
}
|
||||
|
||||
// Function to build BuildSentenceVocabNode
|
||||
|
@ -81,5 +93,16 @@ Status BuildSentenceVocabNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status BuildSentenceVocabNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<BuildSentenceVocabNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status BuildSentenceVocabNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<BuildSentenceVocabNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,18 @@ class BuildSentenceVocabNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~BuildSentenceVocabNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kBuildSentencePieceVocabNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -46,6 +58,18 @@ class BuildSentenceVocabNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<SentencePieceVocab> vocab_;
|
||||
std::vector<std::string> col_names_;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -36,7 +36,17 @@ BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_p
|
|||
top_k_(top_k),
|
||||
special_tokens_(special_tokens),
|
||||
special_first_(special_first) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> BuildVocabNode::Copy() {
|
||||
auto node =
|
||||
std::make_shared<BuildVocabNode>(nullptr, vocab_, columns_, freq_range_, top_k_, special_tokens_, special_first_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void BuildVocabNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(<vocab>," + "columns:" + PrintColumns(columns_) + ",...)";
|
||||
}
|
||||
|
||||
// Function to build BuildVocabNode
|
||||
|
@ -78,5 +88,16 @@ Status BuildVocabNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status BuildVocabNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<BuildVocabNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status BuildVocabNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<BuildVocabNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,6 +37,18 @@ class BuildVocabNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~BuildVocabNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kBuildVocabNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -45,6 +57,18 @@ class BuildVocabNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Vocab> vocab_;
|
||||
std::vector<std::string> columns_;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/concat_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -35,17 +35,25 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets
|
|||
: sampler_(sampler),
|
||||
children_flag_and_nums_(children_flag_and_nums),
|
||||
children_start_end_index_(children_start_end_index) {
|
||||
this->children = datasets;
|
||||
for (auto const &child : datasets) AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> ConcatNode::Copy() {
|
||||
// create an empty vector to copy a concat
|
||||
auto node = std::make_shared<ConcatNode>(std::vector<std::shared_ptr<DatasetNode>>());
|
||||
return node;
|
||||
}
|
||||
|
||||
void ConcatNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status ConcatNode::ValidateParams() {
|
||||
if (children.size() < 2) {
|
||||
if (children_.size() < 2) {
|
||||
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (find(children.begin(), children.end(), nullptr) != children.end()) {
|
||||
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
|
||||
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
|
@ -73,5 +81,16 @@ std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status ConcatNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<ConcatNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status ConcatNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<ConcatNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,18 @@ class ConcatNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~ConcatNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kConcatNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -50,6 +62,18 @@ class ConcatNode : public DatasetNode {
|
|||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -233,14 +233,92 @@ std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
|
|||
return shared_from_this();
|
||||
}
|
||||
|
||||
DatasetNode::DatasetNode() {
|
||||
DatasetNode::DatasetNode() : cache_(nullptr), parent_(nullptr), children_({}) {
|
||||
// Fetch some default value from config manager
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
connector_que_size_ = cfg->op_connector_size();
|
||||
worker_connector_size_ = cfg->worker_connector_size();
|
||||
build_status = Status::OK(); // remove me after changing return val of Build()
|
||||
}
|
||||
|
||||
// this function will preform a deep copy of current node (and its descendants), the parent* pointer will not be copied
|
||||
std::shared_ptr<DatasetNode> DatasetNode::DeepCopy() {
|
||||
std::shared_ptr<DatasetNode> new_node = this->Copy();
|
||||
for (const auto &child : children_) {
|
||||
new_node->AddChild(child->DeepCopy());
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
|
||||
std::string DatasetNode::PrintColumns(const std::vector<std::string> &columns) const {
|
||||
std::string me;
|
||||
if (columns.empty()) {
|
||||
me = "<nil>";
|
||||
} else {
|
||||
me = "[";
|
||||
auto i = 0;
|
||||
for (auto it = columns.begin(); it < columns.end(); ++it, ++i) {
|
||||
me += *it;
|
||||
if (i < columns.size() - 1) {
|
||||
me += ", ";
|
||||
} else {
|
||||
me += "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
return me;
|
||||
}
|
||||
|
||||
void DatasetNode::PrintTree(std::ostream &out) const {
|
||||
int level = 0;
|
||||
PrintNode(out, &level);
|
||||
}
|
||||
|
||||
void DatasetNode::PrintNode(std::ostream &out, int *level) const {
|
||||
const std::string prefix = "+-";
|
||||
const std::string indent = " ";
|
||||
out << prefix;
|
||||
Print(out);
|
||||
for (const auto &c : this->Children()) {
|
||||
out << '\n';
|
||||
++(*level);
|
||||
for (auto i = 0; i < *level; i++) {
|
||||
out << indent;
|
||||
}
|
||||
c->PrintNode(out, level);
|
||||
--(*level);
|
||||
}
|
||||
}
|
||||
|
||||
// Add a node as a child, node's parent needs to be nullptr
|
||||
// this function will allow child to be a nullptr, in which case it will simply skip
|
||||
void DatasetNode::AddChild(std::shared_ptr<DatasetNode> child) {
|
||||
if (child != nullptr && child->parent_ == nullptr) {
|
||||
children_.push_back(child);
|
||||
child->parent_ = this;
|
||||
} else if (child != nullptr) {
|
||||
MS_LOG(WARNING) << "DatasetNode::AddChild() Fail" + child->Name() + "'s parent isn't a nullptr.";
|
||||
}
|
||||
}
|
||||
|
||||
// Remove this node from its parent. Add the child of this node to its parent.
|
||||
// for now, this remove is limited to node with a single child or no child
|
||||
Status DatasetNode::Remove() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(parent_ != nullptr, "Cannot remove root or a node without parent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(children_.size() < 2, "Cannot remove node with more than 1 child.");
|
||||
if (children_.empty()) { // I am a leaf node, remove me from my parent's children list
|
||||
parent_->children_.erase(std::remove(parent_->children_.begin(), parent_->children_.end(), shared_from_this()),
|
||||
parent_->children_.end()); // removal using "erase remove idiom"
|
||||
} else { // replace my position in my parent's children list with my single child
|
||||
auto itr = std::find(parent_->children_.begin(), parent_->children_.end(), shared_from_this());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(itr != parent_->children_.end(), "I am not in my parent's children list.");
|
||||
children_[0]->parent_ = parent_; // set my single child's parent ptr to my parent
|
||||
*itr = std::move(children_[0]); // replace me in my parent's children list with my single child
|
||||
children_.clear(); // release my single child from my children list
|
||||
}
|
||||
parent_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// In DFS tree traversal, each node is visited twice. Accept is called on the first visit.
|
||||
|
@ -255,13 +333,25 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) {
|
|||
// This method will only be called if its derived class does not implement one.
|
||||
return p->VisitAfter(shared_from_this(), modified);
|
||||
}
|
||||
|
||||
Status DatasetNode::GetShardId(int32_t *shard_id) {
|
||||
if (!Children().empty()) {
|
||||
// Get shard id from the child node
|
||||
return Children()[0]->GetShardId(shard_id);
|
||||
} else {
|
||||
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node");
|
||||
RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node: " + Name() + "\n");
|
||||
}
|
||||
}
|
||||
// Visitor accepting method for NodePass
|
||||
Status SourceNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<SourceNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status SourceNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<SourceNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -42,6 +42,45 @@ class NodePass;
|
|||
} \
|
||||
} while (false)
|
||||
|
||||
// Names for non-leaf IR node
|
||||
constexpr char kBatchNode[] = "Batch";
|
||||
constexpr char kBucketBatchByLengthNode[] = "BucketBatchByLength";
|
||||
constexpr char kBuildSentencePieceVocabNode[] = "BuildSentencePieceVocab";
|
||||
constexpr char kBuildVocabNode[] = "BuildVocab";
|
||||
constexpr char kConcatNode[] = "Concat";
|
||||
constexpr char kDatasetNode[] = "Dataset";
|
||||
constexpr char kEpochCtrlNode[] = "EpochCtrl";
|
||||
constexpr char kFilterNode[] = "Filter";
|
||||
constexpr char kMapNode[] = "Map";
|
||||
constexpr char kProjectNode[] = "Project";
|
||||
constexpr char kRenameNode[] = "Rename";
|
||||
constexpr char kRepeatNode[] = "Repeat";
|
||||
constexpr char kRootNode[] = "Top";
|
||||
constexpr char kShuffleNode[] = "Shuffle";
|
||||
constexpr char kSkipNode[] = "Skip";
|
||||
constexpr char kSyncWaitNode[] = "SyncWait";
|
||||
constexpr char kTakeNode[] = "Take";
|
||||
constexpr char kTransferNode[] = "Transfer";
|
||||
constexpr char kZipNode[] = "Zip";
|
||||
|
||||
// Names for leaf IR node
|
||||
constexpr char kAlbumNode[] = "AlbumDataset";
|
||||
constexpr char kCelebANode[] = "CelebADataset";
|
||||
constexpr char kCifar100Node[] = "Cifar100Dataset";
|
||||
constexpr char kCifar10Node[] = "Cifar10Dataset";
|
||||
constexpr char kCLUENode[] = "CLUEDataset";
|
||||
constexpr char kCocoNode[] = "CocoDataset";
|
||||
constexpr char kCSVNode[] = "CSVDataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
constexpr char kManifestNode[] = "ManifestDataset";
|
||||
constexpr char kMindDataNode[] = "MindDataDataset";
|
||||
constexpr char kMnistNode[] = "MnistDataset";
|
||||
constexpr char kRandomNode[] = "RandomDataset";
|
||||
constexpr char kTextFileNode[] = "TextFileDataset";
|
||||
constexpr char kTFRecordNode[] = "TFRecordDataset";
|
||||
constexpr char kVOCNode[] = "VOCDataset";
|
||||
|
||||
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
|
||||
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);
|
||||
|
||||
|
@ -75,6 +114,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data
|
|||
/// \return Shared pointer to the current Sampler.
|
||||
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id);
|
||||
|
||||
// The base class of all IR nodes
|
||||
class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -87,6 +127,36 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \brief Destructor
|
||||
~DatasetNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// \brief Pure virtual function to print the description
|
||||
/// \param out - The output stream to write output to
|
||||
virtual void Print(std::ostream &out) const = 0;
|
||||
|
||||
/// \brief Pure virtual function to make a new copy of the node
|
||||
/// \return The new copy of the node
|
||||
virtual std::shared_ptr<DatasetNode> Copy() = 0;
|
||||
|
||||
/// \brief Print the IR tree to output stream
|
||||
/// \param out - The output stream to write output to
|
||||
void PrintTree(std::ostream &out) const;
|
||||
|
||||
/// \brief << Stream output operator overload
|
||||
/// \notes This allows you to write the debug print info using stream operators
|
||||
/// \param out - reference to the output stream being overloaded
|
||||
/// \param dO - reference to the DatasetOp to display
|
||||
/// \return - the output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, const DatasetNode &node) {
|
||||
node.PrintTree(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// \brief Make a new copy of the tree from the current node
|
||||
/// \return The new copy of the tree
|
||||
std::shared_ptr<DatasetNode> DeepCopy();
|
||||
|
||||
/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
|
||||
|
@ -95,17 +165,38 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; }
|
||||
|
||||
/// \brief Pure virtual function for derived class to get the shard id of specific node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
virtual Status GetShardId(int32_t *shard_id);
|
||||
|
||||
/// \brief Getter function for child nodes
|
||||
/// \return Child nodes
|
||||
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children_; }
|
||||
|
||||
/// \brief Establish the parent-child relationship between this node and its child.
|
||||
void AddChild(std::shared_ptr<DatasetNode> child);
|
||||
|
||||
/// \brief detach this node from its parent, add its child (if any) to its parent
|
||||
/// \return error code, return error if node has more than 1 children
|
||||
Status Remove();
|
||||
|
||||
/// \brief Check if this node has cache
|
||||
/// \return True if the data of this node will be cached
|
||||
const bool IsCached() const { return (cache_ != nullptr); }
|
||||
|
||||
/// \brief Setter function for runtime number of workers
|
||||
/// \param[in] num_workers The number of threads in this operator
|
||||
/// \return Shared pointer to the original object
|
||||
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);
|
||||
|
||||
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
|
||||
/// Similar to shared_from_this, except this one will give you the derived class as shared_ptr
|
||||
/// \return A shared_ptr casted to the derived class
|
||||
template <typename Derived>
|
||||
std::shared_ptr<Derived> shared_from_base() {
|
||||
return std::static_pointer_cast<Derived>(shared_from_this());
|
||||
}
|
||||
|
||||
/// \brief Base method for NodePass visit. A tree walk consists of walking down the tree and also walking back up
|
||||
/// in a depth-first order. Accept is the node visit on the way down, whereas AcceptAfter is the node
|
||||
/// visit on the way back up the tree after its descendants are visited.
|
||||
|
@ -129,17 +220,123 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
Status BuildStatus() { return build_status; }
|
||||
|
||||
protected:
|
||||
std::vector<std::shared_ptr<DatasetNode>> children;
|
||||
std::vector<std::shared_ptr<DatasetNode>> children_;
|
||||
DatasetNode *parent_;
|
||||
std::shared_ptr<DatasetCache> cache_;
|
||||
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
||||
|
||||
int32_t num_workers_;
|
||||
int32_t rows_per_buffer_;
|
||||
int32_t connector_que_size_;
|
||||
int32_t worker_connector_size_;
|
||||
Status build_status; // remove me after changing return val of Build()
|
||||
std::string PrintColumns(const std::vector<std::string> &columns) const;
|
||||
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
|
||||
void PrintNode(std::ostream &out, int *level) const;
|
||||
};
|
||||
|
||||
// SourceNode represents the leaf nodes of a pipeline where the data is pulled into.
|
||||
class SourceNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
SourceNode() : DatasetNode() {}
|
||||
|
||||
/// \brief Constructor that initializes the cache
|
||||
/// \param dataset_cache DatasetCache
|
||||
explicit SourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode(dataset_cache) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~SourceNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Check if this node is a mappable dataset. Only applicable to leaf nodes
|
||||
/// \return True if the dataset represented by this node is a mappable dataset
|
||||
const bool IsMappable() const { return mappable_; }
|
||||
|
||||
protected:
|
||||
bool mappable_;
|
||||
};
|
||||
|
||||
// MappableSourceNode represents the leaf nodes that can be randomly accessed with indexes.
|
||||
class MappableSourceNode : public SourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MappableSourceNode() : SourceNode() { mappable_ = true; }
|
||||
|
||||
/// \brief Constructor that initializes the cache
|
||||
/// \param dataset_cache DatasetCache
|
||||
explicit MappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
|
||||
mappable_ = true;
|
||||
}
|
||||
|
||||
/// \brief Destructor
|
||||
~MappableSourceNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
};
|
||||
|
||||
// NonMappableSourceNode represents the leaf nodes that can not be randomly accessed.
|
||||
class NonMappableSourceNode : public SourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
NonMappableSourceNode() : SourceNode() { mappable_ = false; }
|
||||
|
||||
/// \brief Constructor that initializes the cache
|
||||
/// \param dataset_cache DatasetCache
|
||||
explicit NonMappableSourceNode(const std::shared_ptr<DatasetCache> &dataset_cache) : SourceNode(dataset_cache) {
|
||||
mappable_ = false;
|
||||
}
|
||||
|
||||
/// \brief Destructor
|
||||
~NonMappableSourceNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
};
|
||||
|
||||
// NonLeafNode represents operations over data in a pipeline.
|
||||
class NonLeafNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
NonLeafNode() = default;
|
||||
|
||||
/// \brief Destructor
|
||||
~NonLeafNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
};
|
||||
|
||||
// SinkNode represents the end node of a pipeline where the data is pushed out
|
||||
class SinkNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
SinkNode() = default;
|
||||
|
||||
/// \brief Destructor
|
||||
~SinkNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
virtual std::string Name() const = 0;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/epoch_ctrl_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor for EpochCtrlNode
|
||||
EpochCtrlNode::EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : num_epochs_(num_epochs) {
|
||||
// The root node's parent must set to null pointer.
|
||||
this->AddChild(child);
|
||||
}
|
||||
std::shared_ptr<DatasetNode> EpochCtrlNode::Copy() {
|
||||
auto node = std::make_shared<EpochCtrlNode>(nullptr, this->num_epochs_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void EpochCtrlNode::Print(std::ostream &out) const { out << Name() + "(epoch:" + std::to_string(num_epochs_) + ")"; }
|
||||
|
||||
// Function to build the EpochCtrlOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> EpochCtrlNode::Build() {
|
||||
// A dummy vector
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
node_ops.push_back(std::make_shared<EpochCtrlOp>(num_epochs_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for EpochCtrlNode
|
||||
Status EpochCtrlNode::ValidateParams() {
|
||||
if (num_epochs_ <= 0 && num_epochs_ != -1) {
|
||||
std::string err_msg =
|
||||
"EpochCtrlNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (children_.size() != 1 || children_[0] == nullptr) {
|
||||
std::string err_msg = "Internal error: epoch control node should have one child node";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class EpochCtrlNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
|
||||
|
||||
/// \brief Destructor
|
||||
~EpochCtrlNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kEpochCtrlNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
|
|
@ -21,7 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/filter_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -31,7 +31,16 @@ namespace dataset {
|
|||
FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate,
|
||||
std::vector<std::string> input_columns)
|
||||
: predicate_(predicate), input_columns_(input_columns) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> FilterNode::Copy() {
|
||||
auto node = std::make_shared<FilterNode>(nullptr, predicate_, input_columns_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void FilterNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(<predicate>," + "input_cols:" + PrintColumns(input_columns_) + ")";
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() {
|
||||
|
@ -54,5 +63,17 @@ Status FilterNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status FilterNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<FilterNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status FilterNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<FilterNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,6 +35,18 @@ class FilterNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~FilterNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kFilterNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -43,6 +55,18 @@ class FilterNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<TensorOp> predicate_;
|
||||
std::vector<std::string> input_columns_;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
|
@ -37,7 +38,18 @@ MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr
|
|||
project_columns_(project_columns),
|
||||
DatasetNode(std::move(cache)),
|
||||
callbacks_(callbacks) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> MapNode::Copy() {
|
||||
auto node = std::make_shared<MapNode>(nullptr, operations_, input_columns_, output_columns_, project_columns_, cache_,
|
||||
callbacks_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void MapNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(<ops>" + ",input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) +
|
||||
",<project_cols>" + ",...)";
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
|
||||
|
@ -93,5 +105,16 @@ Status MapNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status MapNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<MapNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status MapNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<MapNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,6 +37,18 @@ class MapNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~MapNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kMapNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -45,6 +57,23 @@ class MapNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Getter of tensor operations
|
||||
/// \return Vector of operations the Map node will process
|
||||
const auto &TensorOperations() const { return operations_; }
|
||||
auto &TensorOperations() { return operations_; }
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<TensorOperation>> operations_;
|
||||
std::vector<std::string> input_columns_;
|
||||
|
|
|
@ -29,9 +29,16 @@ namespace dataset {
|
|||
// Function to build ProjectOp
|
||||
ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
|
||||
: columns_(columns) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> ProjectNode::Copy() {
|
||||
auto node = std::make_shared<ProjectNode>(nullptr, this->columns_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void ProjectNode::Print(std::ostream &out) const { out << Name() + "(column: " + PrintColumns(columns_) + ")"; }
|
||||
|
||||
Status ProjectNode::ValidateParams() {
|
||||
if (columns_.empty()) {
|
||||
std::string err_msg = "ProjectNode: No columns are specified.";
|
||||
|
|
|
@ -34,6 +34,18 @@ class ProjectNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~ProjectNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kProjectNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -30,7 +30,16 @@ namespace dataset {
|
|||
RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
|
||||
const std::vector<std::string> &output_columns)
|
||||
: input_columns_(input_columns), output_columns_(output_columns) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> RenameNode::Copy() {
|
||||
auto node = std::make_shared<RenameNode>(nullptr, input_columns_, output_columns_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RenameNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(input:" + PrintColumns(input_columns_) + ",output:" + PrintColumns(output_columns_) + ")";
|
||||
}
|
||||
|
||||
Status RenameNode::ValidateParams() {
|
||||
|
|
|
@ -35,6 +35,18 @@ class RenameNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~RenameNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kRenameNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -21,15 +21,22 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/repeat_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RepeatNode::RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count) : repeat_count_(count) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> RepeatNode::Copy() {
|
||||
auto node = std::make_shared<RepeatNode>(nullptr, this->repeat_count_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RepeatNode::Print(std::ostream &out) const { out << Name() + "(count:" + std::to_string(repeat_count_) + ")"; }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
@ -49,5 +56,16 @@ Status RepeatNode::ValidateParams() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status RepeatNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<RepeatNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status RepeatNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<RepeatNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,6 +36,18 @@ class RepeatNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~RepeatNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kRepeatNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -44,6 +56,18 @@ class RepeatNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
int32_t repeat_count_;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor for RootNode
|
||||
RootNode::RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs) : DatasetNode(), num_epochs_(num_epochs) {
|
||||
// The root node's parent must remain nullptr. (which is set in the constructor of DatasetNode)
|
||||
AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> RootNode::Copy() {
|
||||
auto node = std::make_shared<RootNode>(nullptr, num_epochs_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RootNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> RootNode::Build() {
|
||||
// root node doesn't build a runtime Op. this function should return Status::Error when called.
|
||||
return {};
|
||||
}
|
||||
|
||||
// Function to validate the parameters for RootNode
|
||||
Status RootNode::ValidateParams() {
|
||||
if (num_epochs_ <= 0 && num_epochs_ != -1) {
|
||||
std::string err_msg =
|
||||
"RootNode: num_epochs should be either -1 or positive integer, num_epochs: " + std::to_string(num_epochs_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (parent_ != nullptr) {
|
||||
std::string err_msg = "Internal error: root node should not have a parent";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (children_.size() != 1) {
|
||||
std::string err_msg = "Internal error: root node should have one child node";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (children_[0] == nullptr) {
|
||||
std::string err_msg = "Internal error: root node's child is a null pointer";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status RootNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<RootNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status RootNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<RootNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RootNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
RootNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
|
||||
|
||||
/// \brief Destructor
|
||||
~RootNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kRootNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Getter of number of epochs
|
||||
int32_t num_epochs() { return num_epochs_; }
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
|
|
@ -29,7 +29,17 @@ namespace dataset {
|
|||
// Constructor for ShuffleNode
|
||||
ShuffleNode::ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch)
|
||||
: shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> ShuffleNode::Copy() {
|
||||
auto node = std::make_shared<ShuffleNode>(nullptr, shuffle_size_, reset_every_epoch_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void ShuffleNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(shuffle_size:" + std::to_string(shuffle_size_) +
|
||||
",reset_every_epoch:" + (reset_every_epoch_ ? "true" : "false") + ")";
|
||||
}
|
||||
|
||||
// Function to build the ShuffleOp
|
||||
|
|
|
@ -34,6 +34,18 @@ class ShuffleNode : public DatasetNode {
|
|||
|
||||
~ShuffleNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kShuffleNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
|
|
@ -27,10 +27,15 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for SkipNode
|
||||
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) {
|
||||
this->children.push_back(child);
|
||||
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { this->AddChild(child); }
|
||||
|
||||
std::shared_ptr<DatasetNode> SkipNode::Copy() {
|
||||
auto node = std::make_shared<SkipNode>(nullptr, skip_count_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" + std::to_string(skip_count_) + ")"; }
|
||||
|
||||
// Function to build the SkipOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
|
|
|
@ -34,6 +34,18 @@ class SkipNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~SkipNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kSkipNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -32,13 +32,23 @@ namespace dataset {
|
|||
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
|
||||
const std::vector<std::string> &column_names, bool decode,
|
||||
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
schema_path_(data_schema),
|
||||
column_names_(column_names),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> AlbumNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<AlbumNode>(dataset_dir_, schema_path_, column_names_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void AlbumNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
|
||||
}
|
||||
|
||||
Status AlbumNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class AlbumNode : public DatasetNode {
|
||||
class AlbumNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
|
||||
|
@ -36,6 +36,18 @@ class AlbumNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~AlbumNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kAlbumNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create a runtime dataset op object from this class
|
||||
/// \return shared pointer to the newly created DatasetOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -31,13 +31,23 @@ namespace dataset {
|
|||
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
|
||||
const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
sampler_(sampler),
|
||||
decode_(decode),
|
||||
extensions_(extensions) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CelebANode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<CelebANode>(dataset_dir_, usage_, sampler, decode_, extensions_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void CelebANode::Print(std::ostream &out) const {
|
||||
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
|
||||
}
|
||||
|
||||
Status CelebANode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CelebANode : public DatasetNode {
|
||||
class CelebANode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
|
@ -37,6 +37,18 @@ class CelebANode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~CelebANode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCelebANode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -30,7 +30,17 @@ namespace dataset {
|
|||
// Constructor for Cifar100Node
|
||||
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> Cifar100Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<Cifar100Node>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void Cifar100Node::Print(std::ostream &out) const {
|
||||
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
|
||||
}
|
||||
|
||||
Status Cifar100Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class Cifar100Node : public DatasetNode {
|
||||
class Cifar100Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
|
@ -35,6 +35,18 @@ class Cifar100Node : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~Cifar100Node() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCifar100Node; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -30,7 +30,17 @@ namespace dataset {
|
|||
// Constructor for Cifar10Node
|
||||
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> Cifar10Node::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<Cifar10Node>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void Cifar10Node::Print(std::ostream &out) const {
|
||||
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ")";
|
||||
}
|
||||
|
||||
Status Cifar10Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class Cifar10Node : public DatasetNode {
|
||||
class Cifar10Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
|
||||
|
@ -35,6 +35,18 @@ class Cifar10Node : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~Cifar10Node() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCifar10Node; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace dataset {
|
|||
// Constructor for CLUENode
|
||||
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
dataset_files_(clue_files),
|
||||
task_(task),
|
||||
usage_(usage),
|
||||
|
@ -41,6 +41,17 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task,
|
|||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CLUENode::Copy() {
|
||||
auto node =
|
||||
std::make_shared<CLUENode>(dataset_files_, task_, usage_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void CLUENode::Print(std::ostream &out) const {
|
||||
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." +
|
||||
",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")";
|
||||
}
|
||||
|
||||
Status CLUENode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace dataset {
|
|||
|
||||
/// \class CLUENode
|
||||
/// \brief A Dataset derived class to represent CLUE dataset
|
||||
class CLUENode : public DatasetNode {
|
||||
class CLUENode : public NonMappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
|
||||
|
@ -37,6 +37,18 @@ class CLUENode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~CLUENode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCLUENode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -30,13 +30,21 @@ namespace dataset {
|
|||
// Constructor for CocoNode
|
||||
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
||||
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
annotation_file_(annotation_file),
|
||||
task_(task),
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CocoNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void CocoNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status CocoNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CocoNode : public DatasetNode {
|
||||
class CocoNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
|
||||
|
@ -35,6 +35,18 @@ class CocoNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~CocoNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCocoNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -33,7 +33,7 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
|
|||
const std::vector<std::shared_ptr<CsvBase>> &column_defaults,
|
||||
const std::vector<std::string> &column_names, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
dataset_files_(csv_files),
|
||||
field_delim_(field_delim),
|
||||
column_defaults_(column_defaults),
|
||||
|
@ -43,6 +43,17 @@ CSVNode::CSVNode(const std::vector<std::string> &csv_files, char field_delim,
|
|||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CSVNode::Copy() {
|
||||
auto node = std::make_shared<CSVNode>(dataset_files_, field_delim_, column_defaults_, column_names_, num_samples_,
|
||||
shuffle_, num_shards_, shard_id_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void CSVNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(cache:" + ((cache_ != nullptr) ? "true" : "false") + ",..." +
|
||||
",num_shards:" + std::to_string(num_shards_) + ",shard_id:" + std::to_string(shard_id_) + ")";
|
||||
}
|
||||
|
||||
Status CSVNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CSVNode", dataset_files_));
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class CsvRecord : public CsvBase {
|
|||
T value;
|
||||
};
|
||||
|
||||
class CSVNode : public DatasetNode {
|
||||
class CSVNode : public NonMappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CSVNode(const std::vector<std::string> &dataset_files, char field_delim,
|
||||
|
@ -58,6 +58,18 @@ class CSVNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~CSVNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kCSVNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -28,7 +28,19 @@ namespace dataset {
|
|||
|
||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
||||
const std::vector<DataType> &column_types)
|
||||
: generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {}
|
||||
: MappableSourceNode(),
|
||||
generator_function_(generator_function),
|
||||
column_names_(column_names),
|
||||
column_types_(column_types) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> GeneratorNode::Copy() {
|
||||
auto node = std::make_shared<GeneratorNode>(generator_function_, column_names_, column_types_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void GeneratorNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(<func>:" + ",columns:" + PrintColumns(column_names_) + ",<col_types>)";
|
||||
}
|
||||
|
||||
GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema)
|
||||
: generator_function_(generator_function), schema_(schema) {}
|
||||
|
|
|
@ -26,10 +26,9 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class GeneratorNode
|
||||
/// \brief A Dataset derived class to represent GeneratorNode dataset
|
||||
class GeneratorNode : public DatasetNode {
|
||||
class GeneratorNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names,
|
||||
|
@ -41,6 +40,18 @@ class GeneratorNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~GeneratorNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kGeneratorNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -33,13 +33,24 @@ ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shar
|
|||
bool recursive, std::set<std::string> extensions,
|
||||
std::map<std::string, int32_t> class_indexing,
|
||||
std::shared_ptr<DatasetCache> cache = nullptr)
|
||||
: dataset_dir_(dataset_dir),
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
decode_(decode),
|
||||
sampler_(sampler),
|
||||
recursive_(recursive),
|
||||
class_indexing_(class_indexing),
|
||||
exts_(extensions),
|
||||
DatasetNode(std::move(cache)) {}
|
||||
exts_(extensions) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> ImageFolderNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node =
|
||||
std::make_shared<ImageFolderNode>(dataset_dir_, decode_, sampler, recursive_, exts_, class_indexing_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void ImageFolderNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(path:" + dataset_dir_ + ",decode:" + (decode_ ? "true" : "false") + ",...)";
|
||||
}
|
||||
|
||||
Status ImageFolderNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace dataset {
|
|||
|
||||
/// \class ImageFolderNode
|
||||
/// \brief A Dataset derived class to represent ImageFolder dataset
|
||||
class ImageFolderNode : public DatasetNode {
|
||||
class ImageFolderNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
|
||||
|
@ -41,6 +41,18 @@ class ImageFolderNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~ImageFolderNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kImageFolderNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -32,13 +32,30 @@ ManifestNode::ManifestNode(const std::string &dataset_file, const std::string &u
|
|||
const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_file_(dataset_file),
|
||||
usage_(usage),
|
||||
decode_(decode),
|
||||
class_index_(class_indexing),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> ManifestNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<ManifestNode>(dataset_file_, usage_, sampler, class_index_, decode_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void ManifestNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(file:" + dataset_file_;
|
||||
if (sampler_ != nullptr) {
|
||||
out << ",sampler";
|
||||
}
|
||||
if (cache_ != nullptr) {
|
||||
out << ",cache";
|
||||
}
|
||||
out << ")";
|
||||
}
|
||||
|
||||
Status ManifestNode::ValidateParams() {
|
||||
std::vector<char> forbidden_symbols = {':', '*', '?', '"', '<', '>', '|', '`', '&', '\'', ';'};
|
||||
for (char c : dataset_file_) {
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class ManifestNode : public DatasetNode {
|
||||
class ManifestNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
ManifestNode(const std::string &dataset_file, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
|
||||
|
@ -36,6 +36,18 @@ class ManifestNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~ManifestNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kManifestNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -30,7 +30,8 @@ namespace dataset {
|
|||
|
||||
MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
|
||||
: dataset_file_(std::string()),
|
||||
: MappableSourceNode(),
|
||||
dataset_file_(std::string()),
|
||||
dataset_files_(dataset_files),
|
||||
search_for_pattern_(false),
|
||||
columns_list_(columns_list),
|
||||
|
@ -41,7 +42,8 @@ MindDataNode::MindDataNode(const std::vector<std::string> &dataset_files, const
|
|||
|
||||
MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<std::string> &columns_list,
|
||||
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded)
|
||||
: dataset_file_(dataset_file),
|
||||
: MappableSourceNode(),
|
||||
dataset_file_(dataset_file),
|
||||
dataset_files_({}),
|
||||
search_for_pattern_(true),
|
||||
columns_list_(columns_list),
|
||||
|
@ -50,6 +52,19 @@ MindDataNode::MindDataNode(const std::string &dataset_file, const std::vector<st
|
|||
sample_bytes_({}),
|
||||
num_padded_(num_padded) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> MindDataNode::Copy() {
|
||||
std::shared_ptr<MindDataNode> node;
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
if (dataset_files_.empty()) {
|
||||
node = std::make_shared<MindDataNode>(dataset_file_, columns_list_, sampler, padded_sample_, num_padded_);
|
||||
} else {
|
||||
node = std::make_shared<MindDataNode>(dataset_files_, columns_list_, sampler, padded_sample_, num_padded_);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
void MindDataNode::Print(std::ostream &out) const { out << Name() + "(file:" + dataset_file_ + ",...)"; }
|
||||
|
||||
Status MindDataNode::ValidateParams() {
|
||||
if (!search_for_pattern_ && dataset_files_.size() > 4096) {
|
||||
std::string err_msg =
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class MindDataNode : public DatasetNode {
|
||||
class MindDataNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MindDataNode(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
|
||||
|
@ -40,6 +40,18 @@ class MindDataNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~MindDataNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kMindDataNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -29,7 +29,15 @@ namespace dataset {
|
|||
|
||||
MnistNode::MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
: MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> MnistNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<MnistNode>(dataset_dir_, usage_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void MnistNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status MnistNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class MnistNode : public DatasetNode {
|
||||
class MnistNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
MnistNode(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler,
|
||||
|
@ -35,6 +35,18 @@ class MnistNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~MnistNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kMnistNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -27,6 +27,18 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
std::shared_ptr<DatasetNode> RandomNode::Copy() {
|
||||
std::shared_ptr<RandomNode> node;
|
||||
if (schema_ != nullptr) {
|
||||
node = std::make_shared<RandomNode>(total_rows_, schema_, columns_list_, cache_);
|
||||
} else {
|
||||
node = std::make_shared<RandomNode>(total_rows_, schema_path_, columns_list_, cache_);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
void RandomNode::Print(std::ostream &out) const { out << Name() + "(num_row:" + std::to_string(total_rows_) + ",...)"; }
|
||||
|
||||
// ValidateParams for RandomNode
|
||||
Status RandomNode::ValidateParams() {
|
||||
if (total_rows_ < 0) {
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RandomNode : public DatasetNode {
|
||||
class RandomNode : public NonMappableSourceNode {
|
||||
public:
|
||||
// Some constants to provide limits to random generation.
|
||||
static constexpr int32_t kMaxNumColumns = 4;
|
||||
|
@ -37,7 +37,7 @@ class RandomNode : public DatasetNode {
|
|||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema, const std::vector<std::string> &columns_list,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(""),
|
||||
schema_(std::move(schema)),
|
||||
|
@ -46,14 +46,27 @@ class RandomNode : public DatasetNode {
|
|||
/// \brief Constructor
|
||||
RandomNode(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
total_rows_(total_rows),
|
||||
schema_path_(schema_path),
|
||||
schema_(nullptr),
|
||||
columns_list_(columns_list) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~RandomNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kRandomNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -31,13 +31,23 @@ namespace dataset {
|
|||
// Constructor for TextFileNode
|
||||
TextFileNode::TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> TextFileNode::Copy() {
|
||||
auto node = std::make_shared<TextFileNode>(dataset_files_, num_samples_, shuffle_, num_shards_, shard_id_, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void TextFileNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(file:..." + ",num_shards:" + std::to_string(num_shards_) +
|
||||
",shard_id:" + std::to_string(shard_id_) + ",cache:" + ((cache_ != nullptr) ? "true" : "false") + ",...)";
|
||||
}
|
||||
|
||||
Status TextFileNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace dataset {
|
|||
|
||||
/// \class TextFileNode
|
||||
/// \brief A Dataset derived class to represent TextFile dataset
|
||||
class TextFileNode : public DatasetNode {
|
||||
class TextFileNode : public NonMappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TextFileNode(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
||||
|
@ -37,6 +37,18 @@ class TextFileNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~TextFileNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kTextFileNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -30,6 +30,23 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
std::shared_ptr<DatasetNode> TFRecordNode::Copy() {
|
||||
std::shared_ptr<TFRecordNode> node;
|
||||
if (schema_obj_ != nullptr) {
|
||||
node = std::make_shared<TFRecordNode>(dataset_files_, schema_obj_, columns_list_, num_samples_, shuffle_,
|
||||
num_shards_, shard_id_, shard_equal_rows_, cache_);
|
||||
} else {
|
||||
node = std::make_shared<TFRecordNode>(dataset_files_, schema_path_, columns_list_, num_samples_, shuffle_,
|
||||
num_shards_, shard_id_, shard_equal_rows_, cache_);
|
||||
}
|
||||
return node;
|
||||
}
|
||||
|
||||
void TFRecordNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(num_samples:" + std::to_string(num_samples_) + ",num_shards:" + std::to_string(num_shards_) +
|
||||
",shard_id:" + std::to_string(shard_id_) + ",...)";
|
||||
}
|
||||
|
||||
// Validator for TFRecordNode
|
||||
Status TFRecordNode::ValidateParams() {
|
||||
if (dataset_files_.empty()) {
|
||||
|
|
|
@ -29,14 +29,14 @@ namespace dataset {
|
|||
|
||||
/// \class TFRecordNode
|
||||
/// \brief A Dataset derived class to represent TFRecord dataset
|
||||
class TFRecordNode : public DatasetNode {
|
||||
class TFRecordNode : public NonMappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \note Parameter 'schema' is the path to the schema file
|
||||
TFRecordNode(const std::vector<std::string> &dataset_files, std::string schema,
|
||||
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
schema_path_(schema),
|
||||
columns_list_(columns_list),
|
||||
|
@ -51,7 +51,7 @@ class TFRecordNode : public DatasetNode {
|
|||
TFRecordNode(const std::vector<std::string> &dataset_files, std::shared_ptr<SchemaObj> schema,
|
||||
const std::vector<std::string> &columns_list, int64_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id, bool shard_equal_rows, std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: NonMappableSourceNode(std::move(cache)),
|
||||
dataset_files_(dataset_files),
|
||||
schema_obj_(schema),
|
||||
columns_list_(columns_list),
|
||||
|
@ -64,6 +64,18 @@ class TFRecordNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~TFRecordNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kTFRecordNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace dataset {
|
|||
VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: DatasetNode(std::move(cache)),
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
dataset_dir_(dataset_dir),
|
||||
task_(task),
|
||||
usage_(usage),
|
||||
|
@ -40,6 +40,14 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
|
|||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> VOCNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = sampler_ == nullptr ? nullptr : sampler_->Copy();
|
||||
auto node = std::make_shared<VOCNode>(dataset_dir_, task_, usage_, class_index_, decode_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void VOCNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status VOCNode::ValidateParams() {
|
||||
Path dir(dataset_dir_);
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class VOCNode : public DatasetNode {
|
||||
class VOCNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
VOCNode(const std::string &dataset_dir, const std::string &task, const std::string &usage,
|
||||
|
@ -37,6 +37,18 @@ class VOCNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~VOCNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kVOCNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -29,7 +29,16 @@ namespace dataset {
|
|||
// Constructor for SyncWaitNode
|
||||
SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback)
|
||||
: condition_name_(condition_name), callback_(callback) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> SyncWaitNode::Copy() {
|
||||
auto node = std::make_shared<SyncWaitNode>(nullptr, condition_name_, callback_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void SyncWaitNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(cond_name:" + condition_name_ + "<pyfunc>" + ")";
|
||||
}
|
||||
|
||||
// Function to build the BarrierOp
|
||||
|
|
|
@ -36,6 +36,18 @@ class SyncWaitNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~SyncWaitNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kSyncWaitNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -27,10 +27,15 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for TakeNode
|
||||
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) {
|
||||
this->children.push_back(child);
|
||||
TakeNode::TakeNode(std::shared_ptr<DatasetNode> child, int32_t count) : take_count_(count) { this->AddChild(child); }
|
||||
|
||||
std::shared_ptr<DatasetNode> TakeNode::Copy() {
|
||||
auto node = std::make_shared<TakeNode>(nullptr, take_count_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + std::to_string(take_count_) + ")"; }
|
||||
|
||||
// Function to build the TakeOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
|
|
|
@ -34,6 +34,18 @@ class TakeNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~TakeNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kTakeNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -39,7 +40,19 @@ TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue
|
|||
total_batch_(total_batch),
|
||||
create_data_info_queue_(create_data_info_queue),
|
||||
device_id_(0) {
|
||||
this->children.push_back(child);
|
||||
this->AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> TransferNode::Copy() {
|
||||
auto node = std::make_shared<TransferNode>(nullptr, queue_name_, device_type_, send_epoch_end_, total_batch_,
|
||||
create_data_info_queue_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void TransferNode::Print(std::ostream &out) const {
|
||||
out << Name() + "(prefetch_size:" + std::to_string(prefetch_size_) +
|
||||
",send_epoch_end:" + (send_epoch_end_ ? "true" : "false") + ",total_batch:" + std::to_string(total_batch_) +
|
||||
")";
|
||||
}
|
||||
|
||||
// Validator for TransferNode
|
||||
|
@ -94,5 +107,16 @@ std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status TransferNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<TransferNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status TransferNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<TransferNode>(), modified);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,6 +35,18 @@ class TransferNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~TransferNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kTransferNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -43,6 +55,20 @@ class TransferNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id);
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
|
||||
private:
|
||||
std::string queue_name_;
|
||||
int32_t device_id_;
|
||||
|
|
|
@ -21,30 +21,36 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/zip_op.h"
|
||||
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) : datasets_(datasets) {
|
||||
for (auto dataset : datasets_) {
|
||||
this->children.push_back(dataset);
|
||||
}
|
||||
ZipNode::ZipNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) {
|
||||
for (auto const &child : datasets) AddChild(child);
|
||||
}
|
||||
|
||||
std::shared_ptr<DatasetNode> ZipNode::Copy() {
|
||||
std::vector<std::shared_ptr<DatasetNode>> empty_vector;
|
||||
empty_vector.clear();
|
||||
auto node = std::make_shared<ZipNode>(empty_vector);
|
||||
return node;
|
||||
}
|
||||
|
||||
void ZipNode::Print(std::ostream &out) const { out << Name(); }
|
||||
|
||||
Status ZipNode::ValidateParams() {
|
||||
if (datasets_.empty()) {
|
||||
std::string err_msg = "ZipNode: datasets to zip are not specified.";
|
||||
if (children_.size() < 2) {
|
||||
std::string err_msg = "ZipNode: input datasets are not specified.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
|
||||
std::string err_msg = "ZipNode: zip datasets should not be null.";
|
||||
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
|
||||
std::string err_msg = "ZipNode: input datasets should not be null.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -56,5 +62,17 @@ std::vector<std::shared_ptr<DatasetOp>> ZipNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status ZipNode::Accept(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->Visit(shared_from_base<ZipNode>(), modified);
|
||||
}
|
||||
|
||||
// Visitor accepting method for NodePass
|
||||
Status ZipNode::AcceptAfter(NodePass *p, bool *modified) {
|
||||
// Downcast shared pointer then call visitor
|
||||
return p->VisitAfter(shared_from_base<ZipNode>(), modified);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,18 @@ class ZipNode : public DatasetNode {
|
|||
/// \brief Destructor
|
||||
~ZipNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
std::string Name() const override { return kZipNode; }
|
||||
|
||||
/// \brief Print the description
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object
|
||||
/// \return A shared pointer to the new copy
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
@ -42,8 +54,17 @@ class ZipNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<DatasetNode>> datasets_;
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status Accept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Base-class override for accepting NodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(NodePass *p, bool *modified) override;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -22,10 +22,12 @@
|
|||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/concat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/filter_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/project_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/rename_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#ifdef ENABLE_PYTHON
|
||||
|
@ -34,34 +36,6 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/coco_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#endif
|
||||
#ifdef ENABLE_PYTHON
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h"
|
||||
#endif
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/voc_node.h"
|
||||
|
||||
//////////////////////////////////
|
||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
||||
|
@ -113,7 +87,12 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Driver method for TreePass
|
||||
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
|
||||
Status TreePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||
if (root_ir == nullptr || modified == nullptr) {
|
||||
return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass");
|
||||
}
|
||||
return this->RunOnTree(root_ir, modified);
|
||||
}
|
||||
|
||||
// Driver method for NodePass
|
||||
Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
||||
|
@ -132,15 +111,23 @@ Status NodePass::Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
|
|||
|
||||
// Helper function to perform DFS visit
|
||||
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
|
||||
RETURN_IF_NOT_OK(node_ir->Accept(this, modified));
|
||||
bool m = false;
|
||||
|
||||
RETURN_IF_NOT_OK(node_ir->Accept(this, &m));
|
||||
*modified |= m;
|
||||
for (const auto &c : node_ir->Children()) {
|
||||
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
|
||||
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, &m));
|
||||
*modified |= m;
|
||||
}
|
||||
return node_ir->AcceptAfter(this, modified);
|
||||
RETURN_IF_NOT_OK(node_ir->AcceptAfter(this, &m));
|
||||
*modified |= m;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to perform BFS visit
|
||||
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modified) {
|
||||
bool m = false;
|
||||
|
||||
// Initialize bfs queue with root
|
||||
std::queue<std::shared_ptr<DatasetNode>> bfsQueue;
|
||||
bfsQueue.push(node_ir);
|
||||
|
@ -152,7 +139,8 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
|
|||
bfsQueue.pop();
|
||||
|
||||
// Run node pass
|
||||
RETURN_IF_NOT_OK(curNode->Accept(this, modified));
|
||||
RETURN_IF_NOT_OK(curNode->Accept(this, &m));
|
||||
*modified |= m;
|
||||
|
||||
// Push children into bfs queue
|
||||
for (const auto &c : curNode->Children()) {
|
||||
|
@ -162,331 +150,119 @@ Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetNode> node_ir, bool *modifi
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// For datasetops IR
|
||||
// For non-leaf IR node
|
||||
Status NodePass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<BatchNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<ConcatNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<FilterNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::VisitAfter(std::shared_ptr<FilterNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<MapNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<ProjectNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<RenameNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<RenameNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<RootNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::VisitAfter(std::shared_ptr<RootNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<SkipNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status NodePass::Visit(std::shared_ptr<SyncWaitNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ZipNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
// For datasetops/source IR
|
||||
Status NodePass::Visit(std::shared_ptr<AlbumNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<CelebANode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<CelebANode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<Cifar100Node> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<Cifar10Node> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::Visit(std::shared_ptr<CLUENode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
Status NodePass::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<CLUENode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
Status NodePass::VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<CocoNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
// For leaf IR Node
|
||||
Status NodePass::Visit(std::shared_ptr<SourceNode> node, bool *modified) {
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<CocoNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::Visit(std::shared_ptr<CSVNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<CSVNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
Status NodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<ImageFolderNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<ManifestNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::Visit(std::shared_ptr<MindDataNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<MnistNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<MnistNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<RandomNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<RandomNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::Visit(std::shared_ptr<TextFileNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status NodePass::Visit(std::shared_ptr<TFRecordNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
#endif
|
||||
|
||||
Status NodePass::Visit(std::shared_ptr<VOCNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
return Visit(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
Status NodePass::VisitAfter(std::shared_ptr<VOCNode> node, bool *modified) {
|
||||
// Fallback to base class visitor by default
|
||||
Status NodePass::VisitAfter(std::shared_ptr<SourceNode> node, bool *modified) {
|
||||
return VisitAfter(std::static_pointer_cast<DatasetNode>(node), modified);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,123 +26,87 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Non-leaf IR node
|
||||
class BatchNode;
|
||||
class BucketBatchByLengthNode;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class BuildSentenceVocabNode;
|
||||
#endif
|
||||
class BuildVocabNode;
|
||||
class ConcatNode;
|
||||
class FilterNode;
|
||||
class MapNode;
|
||||
class ProjectNode;
|
||||
class RenameNode;
|
||||
class RepeatNode;
|
||||
class RootNode;
|
||||
class ShuffleNode;
|
||||
class SkipNode;
|
||||
#ifdef ENABLE_PYTHON
|
||||
class SyncWaitNode;
|
||||
#endif
|
||||
class TakeNode;
|
||||
class TransferNode;
|
||||
class ZipNode;
|
||||
#ifdef ENABLE_PYTHON
|
||||
class SyncWaitNode;
|
||||
#endif
|
||||
#ifndef ENABLE_ANDROID
|
||||
class BuildSentenceVocabNode;
|
||||
#endif
|
||||
// Leaf IR node
|
||||
class AlbumNode;
|
||||
class CelebANode;
|
||||
class Cifar100Node;
|
||||
class Cifar10Node;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CLUENode;
|
||||
#endif
|
||||
class CocoNode;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CSVNode;
|
||||
#endif
|
||||
class ImageFolderNode;
|
||||
class ManifestNode;
|
||||
class MnistNode;
|
||||
class RandomNode;
|
||||
class VOCNode;
|
||||
#ifdef ENABLE_PYTHON
|
||||
class GeneratorNode;
|
||||
#endif
|
||||
class ImageFolderNode;
|
||||
class ManifestNode;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class CLUENode;
|
||||
class CSVNode;
|
||||
class MindDataNode;
|
||||
#endif
|
||||
class MnistNode;
|
||||
class RandomNode;
|
||||
#ifndef ENABLE_ANDROID
|
||||
class TextFileNode;
|
||||
#endif
|
||||
#ifndef ENABLE_ANDROID
|
||||
class TFRecordNode;
|
||||
#endif
|
||||
class VOCNode;
|
||||
|
||||
//////////////////////////////////
|
||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
||||
class BatchOp;
|
||||
|
||||
class MapOp;
|
||||
|
||||
class ProjectOp;
|
||||
|
||||
class RenameOp;
|
||||
|
||||
class SkipOp;
|
||||
|
||||
class ShuffleOp;
|
||||
|
||||
class AlbumOp;
|
||||
|
||||
class RandomDataOp;
|
||||
|
||||
class RepeatOp;
|
||||
|
||||
class TakeOp;
|
||||
|
||||
class ZipOp;
|
||||
|
||||
class DeviceQueueOp;
|
||||
|
||||
class ImageFolderOp;
|
||||
|
||||
class MnistOp;
|
||||
|
||||
class ManifestOp;
|
||||
|
||||
class CifarOp;
|
||||
|
||||
class VOCOp;
|
||||
|
||||
class CocoOp;
|
||||
|
||||
class CelebAOp;
|
||||
|
||||
class EpochCtrlOp;
|
||||
|
||||
class BuildVocabOp;
|
||||
|
||||
class ConcatOp;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
class MindRecordOp;
|
||||
|
||||
class TFReaderOp;
|
||||
|
||||
class CacheOp;
|
||||
|
||||
class CacheMergeOp;
|
||||
|
||||
class CacheLookupOp;
|
||||
|
||||
class BuildSentencePieceVocabOp;
|
||||
|
||||
class ClueOp;
|
||||
|
||||
class CsvOp;
|
||||
|
||||
class TextFileOp;
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
class FilterOp;
|
||||
|
||||
class GeneratorOp;
|
||||
#endif
|
||||
//////////////////////////////////
|
||||
|
@ -175,6 +139,13 @@ class TreePass : public Pass {
|
|||
/// \param[inout] modified Indicate if the tree was modified
|
||||
Status Run(std::shared_ptr<DatasetNode> root_ir, bool *modified) final;
|
||||
|
||||
/// \brief Derived classes may implement the runOnTree function to implement tree transformation.
|
||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
|
||||
|
||||
//////////////////////////////////
|
||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
||||
/// \brief Run the transformation pass against the execution tree.
|
||||
|
@ -191,8 +162,17 @@ class TreePass : public Pass {
|
|||
//////////////////////////////////
|
||||
};
|
||||
|
||||
// NodePass is a basic Pass class which performs transformation on Node visiting.
|
||||
// NodePass is a base Pass class which performs transformation on node visiting.
|
||||
// NodePass implements Visitor design pattern.
|
||||
// The visiting happens twice for each node in the DFS traversal, one on the way down of the traversal,
|
||||
// and the other when all the descending nodes are visited.
|
||||
// Actual transformation is done by implementing a new derived class of NodePass.
|
||||
// The derived class will implement the method Visit()/VisitAfter() passing specified node types
|
||||
// it wants to action on them, overriding the ones defined in NodePass.
|
||||
// If the derived class wants to perform the same action on all node types,
|
||||
// it can simply implement the method Visit()/VisitAfter() passing the base class DatasetNode.
|
||||
// This is made possible by overloading the method Visit()/VisitAfter() on each node type to fall back
|
||||
// to call the Visit()/VisitAfter() in this parent NodePass class.
|
||||
class NodePass : public Pass {
|
||||
public:
|
||||
// Tree traversal order
|
||||
|
@ -223,153 +203,57 @@ class NodePass : public Pass {
|
|||
/// \return Status The error code return
|
||||
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }
|
||||
|
||||
// For datasetops IR
|
||||
// Visit method to be overridden.
|
||||
// Note that member template can not be virtual, any node which wants to work with NodePass
|
||||
// should declare Visit of its own type and override "Accept" from DatasetNode.
|
||||
// Visit()/VisitAfter() method to be overridden.
|
||||
// These pairs of Visit()/VisitAfter() for each derived class of DatasetNode are defined here.
|
||||
// Their implementation are in .cc file to avoid adding the include files of those derived classes.
|
||||
// The implementation simply falls back to call Visit()/VisitAfter of class DatasetNode, the parent of
|
||||
// the derived classes. With this technique, the transformation classes derived from NodePass needs only to
|
||||
// implement Visit()/VisitAfter() passing DatasetNode if it wants to action on any derived classes
|
||||
// of DatasetNode in the same way.
|
||||
// Note that virtual template functions are not permitted in C++.
|
||||
//
|
||||
// Non-leaf IR node
|
||||
virtual Status Visit(std::shared_ptr<BatchNode> node, bool *modified);
|
||||
|
||||
// VisitAfter method to be overridden.
|
||||
// Note that member template can not be virtual, any node which wants to work with NodePass
|
||||
// should declare VisitAfter of its own type and override "AcceptAfter" from DatasetNode.
|
||||
virtual Status VisitAfter(std::shared_ptr<BatchNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<BucketBatchByLengthNode> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<BuildVocabNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<ConcatNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<ConcatNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<FilterNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<FilterNode> node, bool *modified);
|
||||
virtual Status Visit(std::shared_ptr<MapNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<MapNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<ProjectNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<ProjectNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<RenameNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<RenameNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<RepeatNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<RootNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<RootNode> node, bool *modified);
|
||||
virtual Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<ShuffleNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<SkipNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<SkipNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);
|
||||
virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);
|
||||
virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);
|
||||
#ifdef ENABLE_PYTHON
|
||||
virtual Status Visit(std::shared_ptr<SyncWaitNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<SyncWaitNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status Visit(std::shared_ptr<TakeNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<TakeNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<TransferNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<ZipNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<ZipNode> node, bool *modified);
|
||||
|
||||
// For datasetops/source IR
|
||||
virtual Status Visit(std::shared_ptr<AlbumNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<AlbumNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<CelebANode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<CelebANode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<Cifar100Node> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<Cifar100Node> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<Cifar10Node> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<Cifar10Node> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<CLUENode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<CLUENode> node, bool *modified);
|
||||
virtual Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status Visit(std::shared_ptr<CocoNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<CocoNode> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<CSVNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<CSVNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
virtual Status Visit(std::shared_ptr<GeneratorNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<GeneratorNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status Visit(std::shared_ptr<ImageFolderNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<ImageFolderNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<ManifestNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<ManifestNode> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<MindDataNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<MindDataNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status Visit(std::shared_ptr<MnistNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<MnistNode> node, bool *modified);
|
||||
|
||||
virtual Status Visit(std::shared_ptr<RandomNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<RandomNode> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<TextFileNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<TextFileNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status Visit(std::shared_ptr<TFRecordNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<TFRecordNode> node, bool *modified);
|
||||
#endif
|
||||
|
||||
virtual Status Visit(std::shared_ptr<VOCNode> node, bool *modified);
|
||||
|
||||
virtual Status VisitAfter(std::shared_ptr<VOCNode> node, bool *modified);
|
||||
// Leaf IR node
|
||||
virtual Status Visit(std::shared_ptr<SourceNode> node, bool *modified);
|
||||
virtual Status VisitAfter(std::shared_ptr<SourceNode> node, bool *modified);
|
||||
|
||||
//////////////////////////////////
|
||||
// This section of code will be removed once the migration of optimizer from DatasetOp to DatasetNode is done.
|
||||
|
@ -396,86 +280,47 @@ class NodePass : public Pass {
|
|||
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
|
||||
// of its own type and override "Accept" from DatasetOp.
|
||||
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
|
||||
|
||||
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
|
||||
|
||||
virtual Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
|
||||
#endif
|
||||
//////////////////////////////////
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/input_validation_pass.h"
|
||||
|
||||
|
@ -119,11 +120,16 @@ Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::sha
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs) {
|
||||
num_epochs_ = num_epochs;
|
||||
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
|
||||
optimize_ = true; // Always ON (temporary)
|
||||
|
||||
RETURN_UNEXPECTED_IF_NULL(root_ir);
|
||||
RETURN_UNEXPECTED_IF_NULL(input_ir);
|
||||
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
|
||||
|
||||
// Copy the input IR tree and insert under the root node
|
||||
// Create a root node to host the input IR tree
|
||||
auto root_ir = std::make_shared<RootNode>(input_ir->DeepCopy(), num_epochs);
|
||||
MS_LOG(INFO) << "Plan before PrePass:" << '\n' << *root_ir << '\n';
|
||||
|
||||
// Pre-pass of the IR tree
|
||||
RETURN_IF_NOT_OK(PrePass(root_ir));
|
||||
|
@ -136,11 +142,15 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep
|
|||
// Post-pass of the IR tree
|
||||
RETURN_IF_NOT_OK(PostPass(root_ir));
|
||||
|
||||
MS_LOG(INFO) << "Plan after PostPass:" << '\n' << *root_ir << '\n';
|
||||
|
||||
// This will evolve in the long run
|
||||
tree_ = std::make_unique<ExecutionTree>();
|
||||
|
||||
// Build the Execution tree from the child of the root node
|
||||
std::shared_ptr<DatasetOp> root_op;
|
||||
RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op));
|
||||
// We will replace input_ir with root_ir->Children()[0] once IR optimizer is in
|
||||
RETURN_IF_NOT_OK(BuildExecutionTree(input_ir, &root_op));
|
||||
RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
|
||||
|
||||
if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_);
|
||||
|
|
|
@ -67,10 +67,6 @@ class TreeAdapter {
|
|||
// Optional optimizations status
|
||||
bool OptimizationEnabled() const { return optimize_; }
|
||||
|
||||
// Getter function to get the total number of epochs to be run on this tree.
|
||||
// @return total number of epochs
|
||||
int32_t num_epochs() { return num_epochs_; }
|
||||
|
||||
private:
|
||||
// This function runs a mandatory pass checking the syntax and semantics of the IR tree.
|
||||
Status PrePass(std::shared_ptr<DatasetNode> ir);
|
||||
|
|
|
@ -47,6 +47,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<SamplerRT> Build() = 0;
|
||||
|
||||
/// \brief Pure virtual function to copy a SamplerObj class
|
||||
/// \return Shared pointers to the newly copied SamplerObj
|
||||
virtual std::shared_ptr<SamplerObj> Copy() = 0;
|
||||
|
||||
/// \brief Function for derived class to get the shard id of sampler
|
||||
/// \return The shard id of the derived sampler
|
||||
virtual int64_t ShardId() { return 0; }
|
||||
|
@ -132,6 +136,11 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_,
|
||||
even_dist_);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
@ -160,6 +169,10 @@ class PKSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
@ -174,9 +187,8 @@ class PKSamplerObj : public SamplerObj {
|
|||
|
||||
class PreBuiltSamplerObj : public SamplerObj {
|
||||
public:
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler);
|
||||
#endif
|
||||
|
||||
|
@ -188,6 +200,8 @@ class PreBuiltSamplerObj : public SamplerObj {
|
|||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override;
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
|
@ -205,6 +219,8 @@ class RandomSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override { return std::make_shared<RandomSamplerObj>(replacement_, num_samples_); }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
@ -224,6 +240,10 @@ class SequentialSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
@ -243,6 +263,10 @@ class SubsetRandomSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
|
||||
#endif
|
||||
|
@ -262,6 +286,10 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|||
|
||||
std::shared_ptr<SamplerRT> Build() override;
|
||||
|
||||
std::shared_ptr<SamplerObj> Copy() override {
|
||||
return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
}
|
||||
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -32,7 +32,10 @@ class TensorOp;
|
|||
class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TensorOperation();
|
||||
TensorOperation() : random_op_(false) {}
|
||||
|
||||
/// \brief Constructor
|
||||
explicit TensorOperation(bool random) : random_op_(random) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~TensorOperation() = default;
|
||||
|
@ -42,6 +45,13 @@ class TensorOperation : public std::enable_shared_from_this<TensorOperation> {
|
|||
virtual std::shared_ptr<TensorOp> Build() = 0;
|
||||
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
/// \brief Check whether the operation is deterministic.
|
||||
/// \return true if this op is a random op (returns non-deterministic result e.g. RandomCrop)
|
||||
bool IsRandomOp() const { return random_op_; }
|
||||
|
||||
protected:
|
||||
bool random_op_;
|
||||
};
|
||||
|
||||
// Helper function to validate fill value
|
||||
|
|
|
@ -427,7 +427,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
|
|||
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
|
||||
std::vector<dsize_t> cur_ind, size_t cur_dim) {
|
||||
if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
|
||||
dst->CopyLastDimAt(src, cur_ind);
|
||||
RETURN_IF_NOT_OK(dst->CopyLastDimAt(src, cur_ind));
|
||||
} else { // not the last dimension, keep doing recursion
|
||||
dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
|
||||
for (dsize_t i = 0; i < min_ind; i++) {
|
||||
|
|
|
@ -57,7 +57,7 @@ class RandomCropOp : public TensorOp {
|
|||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
// Function breaks out the compute function's image padding functionality and makes available to other Ops
|
||||
// Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op
|
||||
// Using this class as a base - re-structured to allow for RandomCropWithBBox Augmentation Op
|
||||
// @param input: Input is the original Image
|
||||
// @param pad_image: Pointer to new Padded image
|
||||
// @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required
|
||||
|
|
|
@ -570,7 +570,7 @@ class WeightedRandomSampler(BuiltinSampler):
|
|||
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
|
||||
|
||||
Args:
|
||||
weights (list[float]): A sequence of weights, not necessarily summing up to 1.
|
||||
weights (list[float, int]): A sequence of weights, not necessarily summing up to 1.
|
||||
num_samples (int, optional): Number of elements to sample (default=None, all elements).
|
||||
replacement (bool): If True, put the sample ID back for the next draw (default=True).
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_coco_test.cc
|
||||
c_api_dataset_config_test.cc
|
||||
c_api_dataset_csv_test.cc
|
||||
c_api_dataset_ir_node_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
c_api_dataset_minddata_test.cc
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/opt/pre/getter_pass.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestIRNodes : public UT::DatasetOpTesting {
|
||||
public:
|
||||
MindDataTestIRNodes() = default;
|
||||
void SetUp() override { GlobalInit(); }
|
||||
|
||||
// compare the ptr of the nodes in two trees, used to test the deep copy of nodes, will return error code
|
||||
// if (ptr1 == ptr2) does not equal to flag or the two tree has different structures (or node names are not the same)
|
||||
Status CompareTwoTrees(std::shared_ptr<DatasetNode> root1, std::shared_ptr<DatasetNode> root2, bool flag) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(root1 != nullptr && root2 != nullptr, "Error in Compare, nullptr.");
|
||||
if (((root1.get() == root2.get()) != flag) || (root1->Name() != root2->Name())) {
|
||||
std::string err_msg =
|
||||
"Expect node ptr " + root1->Name() + (flag ? "==" : "!=") + root2->Name() + " but they aren't!";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
size_t num_child = root1->Children().size();
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_child == root2->Children().size(),
|
||||
root1->Name() + " has " + std::to_string(num_child) + "child, node #2 has " +
|
||||
std::to_string(root2->Children().size()) + " child.");
|
||||
|
||||
for (size_t ind = 0; ind < num_child; ind++) {
|
||||
RETURN_IF_NOT_OK(CompareTwoTrees(root1->Children()[ind], root2->Children()[ind], flag));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// print the node's name in post order
|
||||
Status PostOrderPrintTree(std::shared_ptr<DatasetNode> ir, std::string &names) {
|
||||
RETURN_UNEXPECTED_IF_NULL(ir);
|
||||
for (auto child : ir->Children()) {
|
||||
RETURN_IF_NOT_OK(PostOrderPrintTree(child, names));
|
||||
}
|
||||
names += (ir->Name() + "->");
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestIRNodes, MindDataTestSimpleDeepCopy) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestSimpleDeepCopy.";
|
||||
|
||||
auto tree1 = RandomData(44)->Repeat(2)->Project({"label"})->Shuffle(10)->Batch(2)->IRNode();
|
||||
|
||||
auto tree2 = tree1->DeepCopy();
|
||||
std::string tree_1_names, tree_2_names;
|
||||
|
||||
ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names));
|
||||
ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names));
|
||||
|
||||
// expected output for the 2 names:
|
||||
// RandomDataset->Repeat->Project->Shuffle->Batch->
|
||||
EXPECT_EQ(tree_1_names, tree_2_names);
|
||||
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree1, true));
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree2, false));
|
||||
|
||||
// verify compare function is correct
|
||||
EXPECT_TRUE(CompareTwoTrees(tree2, tree2, false).IsError());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestIRNodes, MindDataTestZipDeepCopy) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestZipDeepCopy.";
|
||||
|
||||
auto branch1 = RandomData(44)->Project({"label"});
|
||||
auto branch2 = RandomData(44)->Shuffle(10);
|
||||
|
||||
auto tree1 = Zip({branch1, branch2})->Batch(2)->IRNode();
|
||||
|
||||
auto tree2 = tree1->DeepCopy();
|
||||
std::string tree_1_names, tree_2_names;
|
||||
|
||||
ASSERT_OK(PostOrderPrintTree(tree1, tree_1_names));
|
||||
ASSERT_OK(PostOrderPrintTree(tree2, tree_2_names));
|
||||
|
||||
// expected output for the 2 names:
|
||||
// RandomDataset->Project->RandomDataset->Shuffle->Zip->Batch->
|
||||
EXPECT_EQ(tree_1_names, tree_2_names);
|
||||
|
||||
// verify the pointer within the same tree are the same
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree1, true));
|
||||
// verify two trees
|
||||
ASSERT_OK(CompareTwoTrees(tree1, tree2, false));
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestIRNodes, MindDataTestNodeRemove) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestIRNodes-MindDataTestNodeRemove.";
|
||||
|
||||
auto branch1 = RandomData(44)->Project({"label"});
|
||||
auto branch2 = ImageFolder("path");
|
||||
auto tree = Zip({branch1, branch2})->IRNode();
|
||||
/***
|
||||
tree looks like this, we will remove node and test its functionalities
|
||||
Zip
|
||||
/ \
|
||||
Project ImageFolder
|
||||
/
|
||||
RandomData
|
||||
***/
|
||||
auto tree_copy_1 = tree->DeepCopy();
|
||||
ASSERT_EQ(tree_copy_1->Children().size(), 2);
|
||||
// remove the project in the tree and test
|
||||
ASSERT_OK(tree_copy_1->Children()[0]->Remove()); // remove Project from tree
|
||||
ASSERT_OK(CompareTwoTrees(tree_copy_1, Zip({RandomData(44), ImageFolder("path")})->IRNode(), false));
|
||||
// remove the ImageFolder, a leaf node from the tree
|
||||
std::string tree_1_names, tree_2_names;
|
||||
ASSERT_OK(PostOrderPrintTree(tree_copy_1, tree_1_names));
|
||||
EXPECT_EQ(tree_1_names, "RandomDataset->ImageFolderDataset->Zip->");
|
||||
auto tree_copy_2 = tree->DeepCopy();
|
||||
ASSERT_EQ(tree_copy_2->Children().size(), 2);
|
||||
tree_copy_2->Children()[1]->Remove();
|
||||
ASSERT_OK(PostOrderPrintTree(tree_copy_2, tree_2_names));
|
||||
EXPECT_EQ(tree_2_names, "RandomDataset->Project->Zip->");
|
||||
}
|
Loading…
Reference in New Issue