!3294 Refactor opt/pre passes
Merge pull request !3294 from nsyca/removal_pass
This commit is contained in:
commit
df1300d9cb
|
@ -23,7 +23,7 @@
|
|||
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
|
||||
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
|
||||
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
|
||||
#include "minddata/dataset/engine/perf/profiling.h"
|
||||
#include "minddata/dataset/engine/perf/monitor.h"
|
||||
|
@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
|
|||
std::vector<std::unique_ptr<Pass>> pre_actions;
|
||||
// Construct pre actions
|
||||
MS_LOG(INFO) << "Running pre pass loops.";
|
||||
pre_actions.push_back(std::make_unique<InjectionPass>());
|
||||
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
|
||||
pre_actions.push_back(std::make_unique<RemovalPass>());
|
||||
pre_actions.push_back(std::make_unique<CacheTransformPass>());
|
||||
// Apply pre action passes
|
||||
|
|
|
@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
|
|||
add_library(engine-opt OBJECT
|
||||
pass.cc
|
||||
post/repeat_pass.cc
|
||||
pre/cache_pass.cc
|
||||
pre/cache_transform_pass.cc
|
||||
pre/injection_pass.cc
|
||||
pre/removal_nodes.cc
|
||||
pre/epoch_injection_pass.cc
|
||||
pre/removal_pass.cc
|
||||
optional/tensor_op_fusion_pass.cc
|
||||
util/printer_pass.cc
|
||||
|
|
|
@ -1,181 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
|
||||
CachePass::CachePass(CacheTransformPass *transform_pass)
|
||||
: transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {}
|
||||
|
||||
// Identifies the subtree below this node as a cached descendant tree.
|
||||
Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
|
||||
}
|
||||
is_caching_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
|
||||
// transformation
|
||||
Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
|
||||
if (leaf_op_) {
|
||||
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
|
||||
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
|
||||
// using base class pointers.
|
||||
transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node);
|
||||
} else {
|
||||
// If there was no leaf_op set, then this is a non-mappable scenario.
|
||||
|
||||
if (sampler_) {
|
||||
// Grab the sampler that was saved from the leaf and plug it into the cache op
|
||||
node->SetSampler(std::move(sampler_));
|
||||
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
|
||||
} else {
|
||||
// We're a cache op but no sampler was saved from leaf, so create a default sampler
|
||||
const int64_t num_samples = 0;
|
||||
const int64_t start_index = 0;
|
||||
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
node->SetSampler(std::move(sampler_));
|
||||
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
|
||||
}
|
||||
|
||||
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
|
||||
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
|
||||
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Common code for mappable leaf setup.
|
||||
Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (is_caching_ && leaf_op_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// If we are a leaf in the caching path, then save this leaf.
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
|
||||
leaf_op_ = std::move(leaf_op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Common code for non mappable leaf setup.
|
||||
Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (is_caching_ && leaf_op_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
|
||||
// as save it for use by cache op in ascendant tree.
|
||||
if (is_caching_) {
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
|
||||
} else {
|
||||
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
|
||||
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
|
||||
std::shared_ptr<Sampler> sampler_from_leaf;
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
|
||||
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
node->MakeSimpleProducer();
|
||||
}
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,141 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CacheTransformPass;
|
||||
|
||||
/// \class CachePass cache_pass.h
|
||||
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
|
||||
/// transformation. It works in conjunction with the CacheTransformPass
|
||||
class CachePass : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] transform_pass Raw pointer back to controlling tree pass
|
||||
explicit CachePass(CacheTransformPass *transform_pass);
|
||||
|
||||
/// \brief Destructor
|
||||
~CachePass() = default;
|
||||
|
||||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
|
||||
/// transformation
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
/// \brief Common code for mappable leaf setup.
|
||||
/// \param[in] node The leaf node performing setup work.
|
||||
/// \return Status The error code return
|
||||
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||
|
||||
/// \brief Common code for non-mappable leaf setup.
|
||||
/// \param[in] node The leaf node performing setup work.
|
||||
/// \return Status The error code return
|
||||
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||
|
||||
bool is_caching_;
|
||||
std::shared_ptr<DatasetOp> leaf_op_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_
|
|
@ -15,17 +15,177 @@
|
|||
*/
|
||||
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Constructor
|
||||
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {}
|
||||
|
||||
// Identifies the subtree below this node as a cached descendant tree.
|
||||
Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
|
||||
if (is_caching_) {
|
||||
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
|
||||
}
|
||||
is_caching_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
|
||||
// transformation
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
is_caching_ = false; // We a no longer in a cache subtree. clear the flag.
|
||||
if (leaf_op_) {
|
||||
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
|
||||
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
|
||||
// using base class pointers.
|
||||
AddMappableCacheOperators(std::move(leaf_op_), node);
|
||||
} else {
|
||||
// If there was no leaf_op set, then this is a non-mappable scenario.
|
||||
|
||||
if (sampler_) {
|
||||
// Grab the sampler that was saved from the leaf and plug it into the cache op
|
||||
node->SetSampler(std::move(sampler_));
|
||||
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
|
||||
} else {
|
||||
// We're a cache op but no sampler was saved from leaf, so create a default sampler
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
|
||||
node->SetSampler(std::move(sampler_));
|
||||
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
|
||||
}
|
||||
|
||||
// Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache
|
||||
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
|
||||
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Common code for mappable leaf setup.
|
||||
Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (is_caching_ && leaf_op_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// If we are a leaf in the caching path, then save this leaf.
|
||||
if (is_caching_) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
|
||||
leaf_op_ = std::move(leaf_op);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Common code for non mappable leaf setup.
|
||||
Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
|
||||
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
|
||||
if (is_caching_ && leaf_op_) {
|
||||
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
|
||||
}
|
||||
|
||||
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf
|
||||
// as save it for use by cache op in ascendant tree.
|
||||
if (is_caching_) {
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
|
||||
} else {
|
||||
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
|
||||
// remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based)
|
||||
std::shared_ptr<Sampler> sampler_from_leaf;
|
||||
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
|
||||
if (is_caching_) {
|
||||
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
|
||||
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
|
||||
node->MakeSimpleProducer();
|
||||
}
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
|
||||
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Perform leaf node cache tranform identifications
|
||||
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
|
||||
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
|
||||
// Assigns the leaf and cache operators that are involved in a cache transformation
|
||||
void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
|
||||
std::shared_ptr<CacheOp> cache_op) {
|
||||
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
|
||||
}
|
||||
|
||||
// constructor
|
||||
CacheTransformPass::CacheTransformPass() {}
|
||||
|
||||
|
@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
|||
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
|
||||
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
|
||||
// use to execute a transform.
|
||||
std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
|
||||
RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
|
||||
CachePass cache_pass = CachePass();
|
||||
RETURN_IF_NOT_OK(cache_pass.Run(tree, modified));
|
||||
|
||||
// Then, execute the transform for each pair
|
||||
for (auto cache_pair : cache_pairs_) {
|
||||
for (auto cache_pair : cache_pass.cache_pairs()) {
|
||||
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
|
||||
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
|
||||
}
|
||||
|
@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Assigns the leaf and cache operators that are involved in a cache transformation
|
||||
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
|
||||
std::shared_ptr<CacheOp> cache_op) {
|
||||
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,123 @@ class CacheClient;
|
|||
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
|
||||
/// operations
|
||||
class CacheTransformPass : public TreePass {
|
||||
/// \class CachePass
|
||||
/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache
|
||||
/// transformation. It works in conjunction with the CacheTransformPass
|
||||
class CachePass : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] transform_pass Raw pointer back to controlling tree pass
|
||||
CachePass();
|
||||
|
||||
/// \brief Destructor
|
||||
~CachePass() = default;
|
||||
|
||||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree and assigns the operators that
|
||||
/// will be involved in a cache transformation
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache tranform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; }
|
||||
|
||||
private:
|
||||
/// \brief Common code for mappable leaf setup.
|
||||
/// \param[in] node The leaf node performing setup work.
|
||||
/// \return Status The error code return
|
||||
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||
|
||||
/// \brief Common code for non-mappable leaf setup.
|
||||
/// \param[in] node The leaf node performing setup work.
|
||||
/// \return Status The error code return
|
||||
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||
|
||||
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
|
||||
/// \param[in] leaf_op The leaf operator involved in the cache transform
|
||||
/// \param[in] cache_op The cache operator involved in the cache transform
|
||||
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
|
||||
|
||||
bool is_caching_;
|
||||
std::shared_ptr<DatasetOp> leaf_op_;
|
||||
std::shared_ptr<Sampler> sampler_;
|
||||
// The two operators that work together to establish the cache transform
|
||||
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
CacheTransformPass();
|
||||
|
@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass {
|
|||
/// \return Status The error code return
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
|
||||
/// \param[in] leaf_op The leaf operator involved in the cache transform
|
||||
/// \param[in] cache_op The cache operator involved in the cache transform
|
||||
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
|
||||
|
||||
private:
|
||||
/// \brief Helper function to execute the cache transformation.
|
||||
///
|
||||
|
@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass {
|
|||
/// \return Status The error code return
|
||||
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
|
||||
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
|
||||
|
||||
// The two operators that work together to establish the cache transform
|
||||
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
|
@ -25,64 +25,55 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// constructor
|
||||
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}
|
||||
EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {}
|
||||
|
||||
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
|
||||
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
|
||||
if (injection_pass_) {
|
||||
injection_pass_->epoch_ctrl_bypass_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
|
||||
}
|
||||
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
|
||||
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
|
||||
if (injection_pass_) {
|
||||
injection_pass_->epoch_ctrl_bypass_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
|
||||
}
|
||||
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node,
|
||||
bool *modified) {
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Temporary code to prevent the injection of epoch control when cache op is present
|
||||
// Remove this code in cache op phase 2
|
||||
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
if (injection_pass_) {
|
||||
injection_pass_->epoch_ctrl_bypass_ = true;
|
||||
return Status::OK();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
|
||||
}
|
||||
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
|
||||
// Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here.
|
||||
injection_point_ = node->child(0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// constructor
|
||||
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}
|
||||
EpochInjectionPass::EpochInjectionPass() {}
|
||||
|
||||
// Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||
Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass started.";
|
||||
|
||||
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
|
||||
// The finder can make updates to the InjectionPass object.
|
||||
InjectionPass::InjectionFinder finder(this);
|
||||
finder.Run(tree, modified);
|
||||
// The finder can make updates to the EpochInjectionPass object.
|
||||
EpochInjectionPass::InjectionFinder finder(tree->root());
|
||||
RETURN_IF_NOT_OK(finder.Run(tree, modified));
|
||||
|
||||
// The first injection logic is to check if we should inject the epoch control op as the root node.
|
||||
// Do not inject the op if the number of epochs is 1.
|
||||
int32_t num_epochs = tree->num_epochs();
|
||||
if (num_epochs != 1 && !epoch_ctrl_bypass_) {
|
||||
std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point();
|
||||
if (num_epochs != 1 && epoch_inject_node != nullptr) {
|
||||
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
|
||||
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
|
||||
RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
|
||||
std::shared_ptr<DatasetOp> node = tree->root();
|
||||
if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
|
||||
tree->root()->InsertAsParent(epoch_ctrl_op);
|
||||
} else {
|
||||
tree->root()->child(0)->InsertAsParent(epoch_ctrl_op);
|
||||
}
|
||||
epoch_inject_node->InsertAsParent(epoch_ctrl_op);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
|
||||
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
||||
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
@ -26,10 +26,10 @@ namespace dataset {
|
|||
|
||||
class DatasetOp;
|
||||
|
||||
/// \class InjectionPass injection_pass.h
|
||||
/// \class EpochInjectionPass epoch_injection_pass.h
|
||||
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
|
||||
/// parsing.
|
||||
class InjectionPass : public TreePass {
|
||||
class EpochInjectionPass : public TreePass {
|
||||
/// \class InjectionFinder
|
||||
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
|
||||
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
|
||||
|
@ -37,7 +37,10 @@ class InjectionPass : public TreePass {
|
|||
class InjectionFinder : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit InjectionFinder(InjectionPass *injection_pass);
|
||||
explicit InjectionFinder(std::shared_ptr<DatasetOp> node);
|
||||
|
||||
/// \brief Destructor
|
||||
~InjectionFinder() = default;
|
||||
|
||||
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
|
@ -58,24 +61,30 @@ class InjectionPass : public TreePass {
|
|||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Register the DeviceQueueOp for further action.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
std::shared_ptr<DatasetOp> injection_point() { return injection_point_; }
|
||||
|
||||
private:
|
||||
InjectionPass *injection_pass_;
|
||||
std::shared_ptr<DatasetOp> injection_point_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
InjectionPass();
|
||||
EpochInjectionPass();
|
||||
|
||||
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
private:
|
||||
bool epoch_ctrl_bypass_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
|
||||
#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
|
|
@ -1,58 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
|
||||
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
|
||||
|
||||
// Identifies the subtree below this node as a cached descendant tree.
|
||||
Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
|
||||
is_caching_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Resets the tracking of the cache within the tree
|
||||
Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
|
||||
is_caching_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform ShuffleOp removal check.
|
||||
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||
if (is_caching_) {
|
||||
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||
if (removal_pass_) {
|
||||
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
|
||||
} else {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -1,64 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \class RemovalNodes removal_nodes.h
|
||||
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||
/// It works in conjunction with the removal_pass.
|
||||
class RemovalNodes : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||
explicit RemovalNodes(RemovalPass *removal_pass);
|
||||
|
||||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Destructor
|
||||
~RemovalNodes() = default;
|
||||
|
||||
/// \brief Perform ShuffleOp removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
bool is_caching_;
|
||||
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_
|
|
@ -16,32 +16,58 @@
|
|||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
|
||||
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
|
||||
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}
|
||||
|
||||
// Identifies the subtree below this node as a cached descendant tree.
|
||||
Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
|
||||
is_caching_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Resets the tracking of the cache within the tree
|
||||
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
|
||||
is_caching_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform ShuffleOp removal check.
|
||||
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
|
||||
*modified = false;
|
||||
// If we are in a cache descendant tree, then this shuffle op needs to be removed
|
||||
if (is_caching_) {
|
||||
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
|
||||
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// constructor
|
||||
RemovalPass::RemovalPass() {}
|
||||
|
||||
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
// Walk the tree to collect the nodes to remove, then removes them.
|
||||
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
|
||||
MS_LOG(INFO) << "Pre pass: removal pass started.";
|
||||
// Create the removal node pass which can identify which nodes need to be removed.
|
||||
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
|
||||
std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>();
|
||||
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
|
||||
|
||||
// Then, execute the removal of any nodes that were set up for removal
|
||||
for (auto node : removal_nodes_) {
|
||||
for (auto node : removal_nodes->nodes_to_remove()) {
|
||||
node->Remove();
|
||||
}
|
||||
MS_LOG(INFO) << "Pre pass: removal pass complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Adds an operator to the list of operators to be removed
|
||||
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,6 +30,45 @@ class DatasetOp;
|
|||
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
|
||||
/// nodes should be removed, and then removes them.
|
||||
class RemovalPass : public TreePass {
|
||||
/// \class RemovalNodes
|
||||
/// \brief This is a NodePass who's job is to identify which nodes should be removed.
|
||||
/// It works in conjunction with the removal_pass.
|
||||
class RemovalNodes : public NodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] removal_pass Raw pointer back to controlling tree pass
|
||||
RemovalNodes();
|
||||
|
||||
/// \brief Destructor
|
||||
~RemovalNodes() = default;
|
||||
|
||||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform ShuffleOp removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
/// \return All the nodes to be removed
|
||||
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; }
|
||||
|
||||
private:
|
||||
bool is_caching_;
|
||||
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_;
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
RemovalPass();
|
||||
|
@ -42,13 +81,6 @@ class RemovalPass : public TreePass {
|
|||
/// \param[inout] Indicate of the tree was modified.
|
||||
/// \return Status The error code return
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
/// \brief Adds an operator to the list of operators to be removed
|
||||
/// \param[in] dataset_op The operator to add to the removal list
|
||||
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards():
|
|||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
|
||||
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
|
||||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id():
|
|||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
|
||||
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
|
@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard():
|
|||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)
|
||||
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
|
||||
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)
|
||||
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
|
||||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0():
|
|||
num_iter += 1
|
||||
with pytest.raises(Exception) as error_info:
|
||||
partitions(5)
|
||||
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
|
||||
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value)
|
||||
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
|
Loading…
Reference in New Issue