forked from mindspore-Ecosystem/mindspore
!2602 Stage 2 of CacheOp delivery
Merge pull request !2602 from JesseKLee/cache_op_stage2
This commit is contained in: commit b57d4ea2f3
@@ -128,7 +128,7 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {

 Status ConcatOp::PrepareNodePostAction() {
   RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }

@@ -18,23 +18,26 @@
 #include <iomanip>
 #include <iostream>
 #include <memory>
+#include <regex>
 #include <utility>
 #include <string>
 #include <algorithm>

 #include "dataset/engine/execution_tree.h"
 #include "dataset/engine/datasetops/device_queue_op.h"
+#include "dataset/engine/datasetops/source/sampler/sampler.h"
 #include "dataset/engine/data_buffer.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/opt/pass.h"

+#include "utils/system/crc32c.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace dataset {
 // Constructor
-DatasetOp::DatasetOp(int32_t op_connector_size)
+DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
     : oc_queue_size_(op_connector_size),
+      sampler_(sampler),
       operator_id_(kInvalidOperatorId),
       tree_(nullptr),
       state_(OpState::kDeOpIdle),

@@ -150,6 +153,9 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
   }
   out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
       << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
+  if (sampler_) {
+    sampler_->Print(out, show_all);
+  }
 }
 }

@@ -222,11 +228,10 @@ Status DatasetOp::PrepareNodePreAction() {
 Status DatasetOp::PrepareNodePostAction() {
   // If this op does not have any children and it is in a repeat path of the tree...
   if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
-    // push ourselves onto the tree repeat stack. Later, the repeat operator
+    // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
     // above us will consume them.
-    tree_->AddToRepeatStack(shared_from_this());
+    tree_->AddToEOEOpStack(shared_from_this());
   }

   // Creating Connector object for each op.
   // The consumer of the root node is assumed to be one thread.
   // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.

@@ -289,5 +294,56 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
   // This method will only be called if its derived class does not implement one.
   return p->RunOnNode(shared_from_this(), modified);
 }
+
+// A helper function with some common code that leaf nodes can use during
+// prepare phase for checking if they need to assign a sampler to the cache.
+// @return - Status
+Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
+  // If we are a descendant under a cache op and we have a sampler, then save this sampler
+  // to a stack so that the cache can pick it up during its processing above us.
+  if (sampler_) {
+    if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
+      // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
+      // useless to a random data op. It was only being used as a temporary holding until the cache can
+      // be created
+      tree_->AddToSamplerStack(sampler_);
+      MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
+    } else if (!random_access_op) {
+      // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
+      // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
+      return Status(
+        StatusCode::kUnexpectedError, __LINE__, __FILE__,
+        "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
+    }
+  }
+
+  if (!random_access_op) {
+    // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache,
+    // we can remove it now from the base.
+    sampler_.reset();
+  }
+
+  return Status::OK();
+}
+
+uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
+  std::stringstream ss;
+  op->tree_->Print(ss, op);
+  std::string ss_str = ss.str();
+
+  // Filter out the Operator control flags field when generating the check sum
+  ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
+
+  // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline
+  ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
+
+  // The Cache crc and Server cache id fields differ between creating a new cache_client and re-using
+  // the same cache_client later, so we filter out these two fields to allow cache sharing.
+  ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
+
+  uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
+  return cache_crc;
+}
 } // namespace dataset
 } // namespace mindspore
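The normalization step inside GenerateCRC is worth seeing in isolation: volatile, per-run fields are stripped from the textual dump of the pipeline before hashing, so two otherwise identical pipelines (for example, two shards of a distributed run) produce the same cache key. Below is a minimal standalone sketch of that idea, not MindSpore code: std::hash stands in for system::Crc32c, and the field names simply mirror the ones filtered above.

// Standalone sketch (not MindSpore code): normalize a pipeline description by
// stripping volatile fields, then hash it. std::hash stands in for Crc32c here.
#include <functional>
#include <iostream>
#include <regex>
#include <string>

// Remove any line that starts with one of the volatile field names.
std::string Normalize(std::string s) {
  for (const auto &field : {"Operator control flags", "Device id", "device_id"}) {
    s = std::regex_replace(s, std::regex(std::string(field) + ".*\n"), "");
  }
  return s;
}

int main() {
  // Two textual dumps of the "same" pipeline, differing only in device id.
  std::string shard0 = "Op: TFReaderOp\nDevice id: 0\nRows per buffer: 32\n";
  std::string shard1 = "Op: TFReaderOp\nDevice id: 1\nRows per buffer: 32\n";

  std::hash<std::string> hasher;
  // After normalization both shards hash to the same key, so they can share a cache.
  std::cout << (hasher(Normalize(shard0)) == hasher(Normalize(shard1))) << '\n';  // prints 1
  return 0;
}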
@@ -34,6 +34,8 @@ class DataBuffer;

 class NodePass;

+class Sampler;
+
 // The base class DatasetOp is the main tree node. It is an abstract class, so
 // the actual implementation of the operators will be derived from here.
 class DatasetOp : public std::enable_shared_from_this<DatasetOp> {

@@ -55,7 +57,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {

   // Constructor
   // @param op_connector_size - The size for the output connector of this operator.
-  explicit DatasetOp(int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);

   // Destructor
   virtual ~DatasetOp() { tree_ = nullptr; }

@@ -204,6 +207,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return Sets the control flags
   void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }

+  // Setter function
+  // @return Clears the control flags
+  void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
+
   // Register the internal worker connectors. No op unless it is a parallel op
   // @return Status
   virtual Status RegisterWorkerConnectors() { return Status::OK(); }

@@ -271,6 +278,13 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return Pointer to the ExecutionTree the current op belongs to, no ownership
   ExecutionTree *Tree() { return tree_; }

+  // Getter for the sampler
+  // @return Shared pointer to the sampler (may return nullptr)
+  std::shared_ptr<Sampler> sampler() { return sampler_; }
+
+  // Computes a CRC value for the operator
+  static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
+
  protected:
   // Adds a parent operator to this operator
   // @notes External callers do not have access to this function.

@@ -289,8 +303,15 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return - Status
   virtual Status ComputeColMap();

+  // A helper function with some common code that leaf nodes can use during
+  // prepare phase for checking if they need to assign a sampler to the cache.
+  // @param random_access_op - indicate if this is a mappable random access leaf or not
+  // @return - Status
+  Status SaveSamplerForCache(bool random_access_op);
+
   std::vector<std::shared_ptr<DatasetOp>> child_;  // Child nodes
   std::vector<DatasetOp *> parent_;                // Parent nodes. No ownership
+  std::shared_ptr<Sampler> sampler_;               // Some leaf ops might have a sampler
   int32_t oc_queue_size_;                          // Capacity for each out_connector_
   int32_t operator_id_;                            // Generated id for the node
   ExecutionTree *tree_;                            // Back pointer to our tree.
@@ -100,7 +100,7 @@ void MapOp::Print(std::ostream &out, bool show_all) const {
   }
   out << "\n  TensorOps:";
   for (size_t i = 0; i < tfuncs_.size(); i++) {
-    out << " " << tfuncs_[i];
+    out << " " << *(tfuncs_[i].get());
   }
   out << "\n\n";
 }
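The one-line MapOp change above swaps printing the shared_ptr (which streams the raw pointer value) for printing the pointed-to TensorOp through its own operator<<. A small standalone demonstration of the difference, with a made-up TensorOpLike type:

#include <iostream>
#include <memory>

struct TensorOpLike {
  // A printable op; the real TensorOp classes provide a similar operator<<.
  friend std::ostream &operator<<(std::ostream &out, const TensorOpLike &) {
    return out << "<DecodeOp>";
  }
};

int main() {
  auto op = std::make_shared<TensorOpLike>();
  std::cout << op.get() << '\n';  // prints an address like 0x55e3..., not useful in a debug dump
  std::cout << *op << '\n';       // prints "<DecodeOp>" via the object's operator<<
  return 0;
}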
@@ -26,8 +26,8 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size)
-    : DatasetOp(op_connector_size),
+ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : DatasetOp(op_connector_size, sampler),
       num_workers_(num_workers),
       num_producers_(num_workers),
       worker_connector_size_(1),
@@ -38,7 +38,8 @@ class ParallelOp : public DatasetOp {
   // Constructor
   // @param num_workers
   // @param op_connector_size - size of the output connector for this operator
-  ParallelOp(int32_t num_workers, int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);

   // Destructor
   ~ParallelOp() = default;
@@ -20,7 +20,8 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-PipelineOp::PipelineOp(int32_t op_connector_size) : DatasetOp(op_connector_size) {}
+PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : DatasetOp(op_connector_size, sampler) {}

 // A print method typically used for debugging
 void PipelineOp::Print(std::ostream &out, bool show_all) const {
@@ -32,7 +32,8 @@ class PipelineOp : public DatasetOp {
   // Constructor
   // @param op_connector_size - size of the output connector
   // @return Builder setter method returns reference to the builder.
-  explicit PipelineOp(int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);

   // Destructor
   ~PipelineOp() = default;
@@ -82,14 +82,14 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
 Status RepeatOp::PrepareNodePostAction() {
   // Run any common code from super class first before adding our own specific logic
   RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
+  std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
   while (leaf_op != nullptr) {
     // Track the leaf operators that are under this repeat op.
     eoe_ops_.push_back(leaf_op);
-    leaf_op = tree_->PopFromRepeatStack();
+    leaf_op = tree_->PopFromEOEOpStack();
   }
   // Push ourselves to the stack in case one of our ascendants is a repeat too.
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }
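The repeat/leaf handshake above is easier to follow as a whole: during the post-order prepare walk, childless ops on a repeat path push themselves onto the tree's eoe-op stack, and the nearest RepeatOp ancestor drains that stack into its eoe_ops_ list, then pushes itself in case an outer repeat exists. A self-contained sketch of that pattern, not MindSpore code, with hypothetical Node/Prepare names:

#include <iostream>
#include <memory>
#include <stack>
#include <string>
#include <vector>

struct Node : std::enable_shared_from_this<Node> {
  std::string name;
  bool is_repeat = false;
  std::vector<std::shared_ptr<Node>> children;
  std::vector<std::shared_ptr<Node>> eoe_ops;  // filled only on repeat nodes
};

void Prepare(const std::shared_ptr<Node> &op, std::stack<std::shared_ptr<Node>> &eoe_stack) {
  for (auto &child : op->children) Prepare(child, eoe_stack);  // children first (post-order)
  if (op->children.empty()) {
    // Leaf: hand ourselves up the tree (the real code also checks a kDeOpRepeated flag).
    eoe_stack.push(op->shared_from_this());
  } else if (op->is_repeat) {
    while (!eoe_stack.empty()) {             // repeat: consume everything below us
      op->eoe_ops.push_back(eoe_stack.top());
      eoe_stack.pop();
    }
    eoe_stack.push(op->shared_from_this());  // ...and register for any outer repeat
  }
}

int main() {
  auto leaf = std::make_shared<Node>(); leaf->name = "leaf";
  auto repeat = std::make_shared<Node>(); repeat->name = "repeat"; repeat->is_repeat = true;
  repeat->children.push_back(leaf);
  std::stack<std::shared_ptr<Node>> eoe_stack;
  Prepare(repeat, eoe_stack);
  std::cout << "repeat tracks " << repeat->eoe_ops.size() << " eoe op(s)\n";  // 1
  return 0;
}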
@@ -70,13 +70,12 @@ Status CelebAOp::Builder::SanityCheck() {
 CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
                    bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
                    std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(dir),
       decode_(decode),
       extensions_(exts),
       data_schema_(std::move(schema)),
-      sampler_(std::move(sampler)),
       num_rows_in_attr_file_(0),
       dataset_type_(dataset_type) {
   attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
@@ -221,7 +221,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   bool decode_;
   std::set<std::string> extensions_;  // extensions allowed
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
   int64_t num_rows_in_attr_file_;  // rows number specified in attr file
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
@@ -79,12 +79,11 @@ Status CifarOp::Builder::SanityCheck() {

 CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
                  int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_works, queue_size),
+    : ParallelOp(num_works, queue_size, std::move(sampler)),
       cifar_type_(type),
       rows_per_buffer_(rows_per_buf),
       folder_path_(file_dir),
       data_schema_(std::move(data_schema)),
-      sampler_(std::move(sampler)),
       row_cnt_(0),
       buf_cnt_(0) {
   constexpr uint64_t kUtilQueueSize = 512;
@@ -216,7 +216,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   int32_t rows_per_buffer_;
   std::string folder_path_;
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
@@ -65,7 +65,7 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str
                              bool recursive, bool do_decode, const std::set<std::string> &exts,
                              const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
                              std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_wkrs, queue_size),
+    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(file_dir),
       recursive_(recursive),

@@ -73,7 +73,6 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str
       extensions_(exts),
       class_index_(map),
       data_schema_(std::move(data_schema)),
-      sampler_(std::move(sampler)),
       row_cnt_(0),
       buf_cnt_(0),
       sampler_ind_(0),
@@ -259,7 +259,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   std::set<std::string> extensions_;  // extensions allowed
   std::map<std::string, int32_t> class_index_;
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
   int64_t sampler_ind_;
@@ -64,7 +64,7 @@ Status ManifestOp::Builder::SanityCheck() {
 ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
                        const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
                        std::shared_ptr<Sampler> sampler, std::string usage)
-    : ParallelOp(num_works, queue_size),
+    : ParallelOp(num_works, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       io_block_pushed_(0),
       row_cnt_(0),

@@ -72,7 +72,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
       data_schema_(std::move(data_schema)),
       file_(file),
       class_index_(class_index),
-      sampler_(std::move(sampler)),
       decode_(decode),
       usage_(usage),
       buf_cnt_(0) {
@@ -230,7 +230,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   std::unique_ptr<DataSchema> data_schema_;
   std::string file_;  // file that stores the information of images
   std::map<std::string, int32_t> class_index_;
-  std::shared_ptr<Sampler> sampler_;
   bool decode_;
   std::string usage_;
   int64_t buf_cnt_;
@@ -66,12 +66,11 @@ Status MnistOp::Builder::SanityCheck() {

 MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
                  std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       buf_cnt_(0),
       row_cnt_(0),
       folder_path_(folder_path),
       rows_per_buffer_(rows_per_buffer),
-      sampler_(std::move(sampler)),
       data_schema_(std::move(data_schema)) {
   io_block_queues_.Init(num_workers, queue_size);
 }
@@ -235,7 +235,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   WaitPost wp_;
   std::string folder_path_;  // directory of image folder
   int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<DataSchema> data_schema_;
   std::vector<MnistLabelPair> image_label_pairs_;
   std::vector<std::string> image_names_;
@@ -21,6 +21,7 @@
 #include "dataset/core/config_manager.h"
 #include "dataset/util/random.h"
 #include "dataset/util/wait_post.h"
+#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"

 namespace mindspore {
 namespace dataset {

@@ -30,7 +31,8 @@ RandomDataOp::Builder::Builder()
       builder_num_workers_(0),
       builder_op_connector_size_(0),
       builder_rows_per_buffer_(0),
-      builder_total_rows_(0) {
+      builder_total_rows_(0),
+      builder_sampler_(nullptr) {
   // Some arguments to the RandomDataOp have a default argument that is taken from the config.
   // The user may override these defaults by using the builder set methods.
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();

@@ -43,8 +45,9 @@ RandomDataOp::Builder::Builder()
 Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
   RETURN_IF_NOT_OK(SanityCheck());

-  *out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
-                                           builder_total_rows_, std::move(builder_data_schema_));
+  *out_op =
+    std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
+                                   builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_));

   // If the user did not provide a schema, then we will ask the op to generate a pseudo-random
   // schema.

@@ -66,8 +69,8 @@ Status RandomDataOp::Builder::SanityCheck() const {

 // Constructor for RandomDataOp
 RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
-                           std::unique_ptr<DataSchema> data_schema)
-    : ParallelOp(num_workers, op_connector_size),
+                           std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       buffer_id_(0),
       rows_per_buffer_(rows_per_buffer),
       total_rows_(total_rows),

@@ -124,7 +127,7 @@ Status RandomDataOp::GenerateSchema() {
   // For each column:
   // - choose a datatype
   // - generate a shape that randomly chooses the number of dimensions and the dimension values.
-  DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(0, DataType::NUM_OF_TYPES - 2));
+  DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - 2));
   int32_t rank = GenRandomInt(1, kMaxRank);
   std::vector<dsize_t> dims;
   for (int32_t d = 0; d < rank; d++) {

@@ -412,5 +415,15 @@ Status RandomDataOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status RandomDataOp::PrepareNodePostAction() {
+  // Run common code from super class before adding RandomDataOp specific handling
+  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
+  // Specific handling for this op: we need to do cache op work to assign the sampler to the cache.
+  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false));
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
@@ -42,7 +42,7 @@ class RandomDataOp : public ParallelOp {
   // Some constants to provide limits to random generation.
   static constexpr int32_t kMaxNumColumns = 4;
   static constexpr int32_t kMaxRank = 4;
-  static constexpr int32_t kMaxDimValue = 2048;
+  static constexpr int32_t kMaxDimValue = 32;
   static constexpr int32_t kMaxTotalRows = 1024;

   // A nested builder class to aid in the construction of a RandomDataOp

@@ -117,6 +117,14 @@ class RandomDataOp : public ParallelOp {
       return *this;
     }

+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     /**
      * Check if the required parameters are set by the builder.

@@ -125,6 +133,7 @@ class RandomDataOp : public ParallelOp {
     Status SanityCheck() const;

     std::unique_ptr<DataSchema> builder_data_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;
     int32_t builder_num_workers_;
     int32_t builder_op_connector_size_;
     int64_t builder_rows_per_buffer_;

@@ -139,10 +148,11 @@ class RandomDataOp : public ParallelOp {
   * @param rows_per_buffer - The number of rows in each DataBuffer
   * @param data_schema - A user-provided schema
   * @param total_rows - The total number of rows in the dataset
+  * @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   * @return Builder - The modified builder by reference
   */
  RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
-              std::unique_ptr<DataSchema> data_schema);
+              std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);

  /**
   * Destructor

@@ -193,6 +203,12 @@ class RandomDataOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "RandomDataOp"; }

+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override;
+
  private:
   /**
    * The entry point code for when workers are launched
@@ -107,12 +107,11 @@ Status DistributedSampler::ResetSampler() {
 }

 void DistributedSampler::Print(std::ostream &out, bool show_all) const {
-  out << "(sampler): DistributedSampler\n";
+  out << "\nSampler: DistributedSampler";
   if (show_all) {
-    out << "seed_: " << seed_ << '\n';
-    out << "device_id_: " << device_id_ << '\n';
-    out << "num_devices_: " << num_devices_ << '\n';
-    out << "shuffle_: " << shuffle_ << '\n';
+    Sampler::Print(out, show_all);
+    out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
+        << "\nshuffle: " << shuffle_;
   }
 }
@@ -113,5 +113,13 @@ Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
   return Status::OK();
 }

+void PKSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: PKSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 } // namespace dataset
 } // namespace mindspore
@@ -56,6 +56,11 @@ class PKSampler : public Sampler {  // NOT YET FINISHED
   // @return - The error code return
   Status ResetSampler() override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   bool shuffle_;
   uint32_t seed_;
@@ -103,5 +103,14 @@ Status PythonSampler::ResetSampler() {

   return Status::OK();
 }

+void PythonSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: PythonSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 } // namespace dataset
 } // namespace mindspore
@@ -50,6 +50,11 @@ class PythonSampler : public Sampler {
   // @return - The error code return
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   bool need_to_reset_;  // Whether Reset() should be called before calling GetNextBuffer()
@@ -113,13 +113,12 @@ Status RandomSampler::ResetSampler() {
 }

 void RandomSampler::Print(std::ostream &out, bool show_all) const {
-  out << "(sampler): RandomSampler\n";
-
+  out << "\nSampler: RandomSampler";
   if (show_all) {
-    out << "num_samples_: " << num_samples_ << '\n';
-    out << "next_id_: " << next_id_ << '\n';
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
   }
 }

 } // namespace dataset
 } // namespace mindspore
@@ -80,11 +80,12 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
 }

 void Sampler::Print(std::ostream &out, bool show_all) const {
-  out << "(sampler): base\n";
-
+  // Sampler printing is usually only called in the show_all mode.
+  // Derived classes will display the name, then call back to this base
+  // for common info.
+  // No-op in the summary mode.
   if (show_all) {
-    out << "num_rows_: " << num_rows_ << '\n';
-    out << "num_samples_: " << num_samples_ << '\n';
+    out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
   }
 }
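The printer refactor across the samplers follows one convention: the derived Print writes the sampler's name unconditionally, and only under show_all does it call back to Sampler::Print for the common fields before appending its own. A tiny standalone sketch of that chaining, not MindSpore code, with made-up class names:

#include <iostream>

struct SamplerBase {
  virtual ~SamplerBase() = default;
  virtual void Print(std::ostream &out, bool show_all) const {
    if (show_all) out << "\nnum_rows_: " << num_rows_;  // common detail only
  }
  int num_rows_ = 100;
};

struct SequentialLike : SamplerBase {
  void Print(std::ostream &out, bool show_all) const override {
    out << "\nSampler: SequentialLike";            // name is always printed
    if (show_all) {
      SamplerBase::Print(out, show_all);           // common info from the base
      out << "\nStart index: " << start_index_;    // then our own detail
    }
  }
  int start_index_ = 0;
};

int main() {
  SequentialLike s;
  s.Print(std::cout, true);
  std::cout << '\n';
  return 0;
}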
@@ -89,7 +89,14 @@ Status SequentialSampler::ResetSampler() {
   return Status::OK();
 }

-void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }
-
+void SequentialSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: SequentialSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info
+    out << "\nStart index: " << start_index_;
+  }
+}
 } // namespace dataset
 } // namespace mindspore
@@ -49,6 +49,9 @@ class SequentialSampler : public Sampler {
   // @return - The error code return
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
   void Print(std::ostream &out, bool show_all) const override;

  private:
@@ -119,5 +119,14 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe

   return Status::OK();
 }

+void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: SubsetRandomSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 } // namespace dataset
 } // namespace mindspore
@@ -51,6 +51,11 @@ class SubsetRandomSampler : public Sampler {
   // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   // A list of indices (already randomized in constructor).
   std::vector<int64_t> indices_;
@@ -156,5 +156,14 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf

   return Status::OK();
 }

+void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: WeightedRandomSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 } // namespace dataset
 } // namespace mindspore
@@ -53,6 +53,11 @@ class WeightedRandomSampler : public Sampler {
   // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   // A list of weights for each sample.
   std::vector<double> weights_;
@@ -33,7 +33,11 @@
 namespace mindspore {
 namespace dataset {
 TextFileOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_total_rows_(0),
+      builder_shuffle_files_(false),
+      builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();

@@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
     std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
-    builder_num_devices_, builder_device_id_);
+    builder_num_devices_, builder_device_id_, std::move(builder_sampler_));
   RETURN_IF_NOT_OK(text_file_op->Init());
   *op = std::move(text_file_op);

@@ -73,8 +77,9 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {

 TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
-                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
-    : ParallelOp(num_workers, op_connector_size),
+                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
+                       std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       device_id_(device_id),
       num_devices_(num_device),
       rows_per_buffer_(rows_per_buffer),
@@ -20,6 +20,7 @@
 #include <map>
 #include <mutex>
 #include <string>
+#include <utility>
 #include <vector>

 #include "dataset/util/status.h"

@@ -112,6 +113,14 @@ class TextFileOp : public ParallelOp {
       return *this;
     }

+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     int32_t builder_device_id_;
     int32_t builder_num_devices_;

@@ -123,6 +132,7 @@ class TextFileOp : public ParallelOp {
     std::vector<std::string> builder_text_files_list_;
     bool builder_shuffle_files_;
     std::unique_ptr<DataSchema> builder_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;
   };

   // Constructor of TextFileOp

@@ -136,9 +146,10 @@ class TextFileOp : public ParallelOp {
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
+  // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
              std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
-             bool shuffle_files, int32_t num_devices, int32_t device_id);
+             bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);

   // Default destructor
   ~TextFileOp() = default;
@@ -48,7 +48,11 @@
 namespace mindspore {
 namespace dataset {
 TFReaderOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_total_rows_(0),
+      builder_equal_rows_per_shard_(false),
+      builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_worker_connector_size_ = config_manager->worker_connector_size();

@@ -87,11 +91,6 @@ Status TFReaderOp::Builder::ValidateInputs() const {
     err_msg += "Number of parallel workers is smaller or equal to 0\n";
   }

-  if (!builder_equal_rows_per_shard_ &&
-      builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
-    err_msg += "Not enough tfrecord files provided\n";
-  }
-
   if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
     err_msg += "Wrong sharding configs\n";
   }

@@ -125,7 +124,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
   std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
     builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
     builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
-    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
+    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_,
+    std::move(builder_sampler_));

   RETURN_IF_NOT_OK(new_tf_reader_op->Init());
   *out_tf_reader_op = std::move(new_tf_reader_op);

@@ -136,8 +136,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
                        int64_t total_num_rows, std::vector<std::string> dataset_files_list,
                        std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
                        std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
-                       int32_t device_id, bool equal_rows_per_shard)
-    : ParallelOp(num_workers, op_connector_size),
+                       int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       device_id_(device_id),
       num_devices_(num_device),
       rows_per_buffer_(rows_per_buffer),

@@ -1018,5 +1018,40 @@ Status TFReaderOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status TFReaderOp::PrepareNodePostAction() {
+  // Run common code from super class before adding TFReaderOp specific handling
+  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
+
+  // Specific handling for this op: we need to do cache op work, so assign the sampler to the cache.
+  // TF is a special case because it can support file-based sharding/shuffling, or, if there
+  // is a cache, then it can also do row-based sampling using the sampler on the cache.
+  // Thus, pass true for the random access op flag when saving the sampler. This is a special case,
+  // since usually a non-mappable dataset would pass false here.
+  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
+
+  // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
+  // a simpler producer of all data (no shuffling or sharding or anything)
+  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
+    device_id_ = 0;
+    num_devices_ = 1;
+    total_rows_ = 0;
+    shuffle_files_ = false;
+    equal_rows_per_shard_ = false;
+    sampler_.reset();  // Normally the SaveSampler code did this for us, but we passed in true above (see comment)
+  } else {
+    // This sanity check had been delayed until now in the prepare loop.
+    // If we are not in a cache path, then we can validate the file-based sharding config.
+    // If we are in a cache path, there is no file-based sharding, so the check is not correct in that
+    // situation.
+    if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {
+      RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n");
+    }
+  }
+
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
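TFReaderOp::PrepareNodePostAction leans on the tree's prepare flags: kDePrepCache is set while the walk descends a cache subtree, so the op can test it on the way back up and strip its sharding config. The bit helpers are presumably simple mask operations; the sketch below shows the assumed semantics and is not the MindSpore source:

#include <cstdint>
#include <iostream>

constexpr uint32_t kDePrepNone = 0;
constexpr uint32_t kDePrepRepeat = 1;
constexpr uint32_t kDePrepCache = 2;

// Assumed implementations of the helpers used throughout the diff.
inline void BitSet(uint32_t *bits, uint32_t mask) { *bits |= mask; }
inline void BitClear(uint32_t *bits, uint32_t mask) { *bits &= ~mask; }
inline bool BitTest(uint32_t bits, uint32_t mask) { return (bits & mask) == mask; }

int main() {
  uint32_t prepare_flags = kDePrepNone;
  BitSet(&prepare_flags, kDePrepCache);                        // entering a cache subtree
  std::cout << BitTest(prepare_flags, kDePrepCache);           // 1: leaf sees the cache path
  BitClear(&prepare_flags, kDePrepCache);                      // done preparing this subtree
  std::cout << BitTest(prepare_flags, kDePrepCache) << '\n';   // 0
  return 0;
}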
@@ -153,8 +153,17 @@ class TFReaderOp : public ParallelOp {
       return *this;
     }

+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     std::unique_ptr<DataSchema> builder_data_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;
     int32_t builder_device_id_;
     int32_t builder_num_devices_;
     int32_t builder_num_workers_;

@@ -180,10 +189,11 @@ class TFReaderOp : public ParallelOp {
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
+  // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
              std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
              int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
-             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
+             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler);

   // Default destructor
   ~TFReaderOp() = default;

@@ -236,6 +246,12 @@ class TFReaderOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return dataset_files_list_; }

+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override;
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
@@ -88,7 +88,7 @@ Status VOCOp::Builder::SanityCheck() {
 VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
              const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
              int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       decode_(decode),
       row_cnt_(0),
       buf_cnt_(0),

@@ -97,7 +97,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
       folder_path_(folder_path),
       class_index_(class_index),
       rows_per_buffer_(rows_per_buffer),
-      sampler_(std::move(sampler)),
       data_schema_(std::move(data_schema)) {
   io_block_queues_.Init(num_workers_, queue_size);
 }
@@ -274,7 +274,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   TaskType task_type_;
   std::string task_mode_;
   int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<DataSchema> data_schema_;

   WaitPost wp_;
@@ -129,7 +129,7 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D

 Status TakeOp::PrepareNodePostAction() {
   RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }
@@ -88,13 +88,13 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) {
 }

 // A print method typically used for debugging
-void ExecutionTree::Print(std::ostream &out) const {
+void ExecutionTree::Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op) const {
   out << "Execution tree summary:\n"
       << "-----------------------\n";
-  this->PrintNode(out, root_, "", true, false);
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, false);
   out << "\nExecution tree operator details:\n"
       << "--------------------------------\n";
-  this->PrintNode(out, root_, "", true, true);
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, true);
 }

 // A helper function for doing the recursive printing

@@ -269,27 +269,40 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
     RETURN_IF_NOT_OK(this->PrepareNode(i));
   }

-  // Then clear the flags from this op now that we have prepared it.
-  BitClear(&prepare_flags_, op_prep_flags);
-
   // No more children, now we execute any prepare actions before going back up
   // the tree in the recursive function
   RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());

+  // Then clear the flags from this op now that we have prepared it.
+  BitClear(&prepare_flags_, op_prep_flags);
+
   return Status::OK();
 }

-// Adds an operator to the repeat stack during prepare phase.
-void ExecutionTree::AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op) { repeat_stack_.push(dataset_op); }
+// Adds an operator to the eoe operator stack during prepare phase.
+void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }

-// Pops an operator from the repeat stack during prepare phase.
-std::shared_ptr<DatasetOp> ExecutionTree::PopFromRepeatStack() {
+// Pops an operator from the eoe operator stack during prepare phase.
+std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
   std::shared_ptr<DatasetOp> top_op = nullptr;
-  if (!repeat_stack_.empty()) {
-    top_op = repeat_stack_.top();
-    repeat_stack_.pop();
+  if (!eoe_stack_.empty()) {
+    top_op = eoe_stack_.top();
+    eoe_stack_.pop();
   }
   return top_op;
 }
+
+// Adds a sampler to the sampler stack during prepare phase.
+void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
+
+// Pops a sampler from the sampler stack during prepare phase.
+std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
+  std::shared_ptr<Sampler> top_sampler = nullptr;
+  if (!sampler_stack_.empty()) {
+    top_sampler = sampler_stack_.top();
+    sampler_stack_.pop();
+  }
+  return top_sampler;
+}
 } // namespace dataset
 } // namespace mindspore
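Note the subtle reorder in PrepareNode: BitClear now runs after PrepareNodePostAction, so post-actions such as SaveSamplerForCache still see the prepare flags (for example kDePrepCache) for their own subtree. A minimal sketch of why the order matters, not MindSpore code:

#include <cstdint>
#include <iostream>

constexpr uint32_t kDePrepCache = 2;

// Stand-in for a post-action that asks "am I under a cache?" via the prepare flags.
bool PostAction(uint32_t flags) { return (flags & kDePrepCache) != 0; }

int main() {
  uint32_t flags = kDePrepCache;
  // Fixed order in PrepareNode: run the post-action first, then clear the flag.
  bool saw_cache = PostAction(flags);  // true: the sampler gets saved for the cache
  flags &= ~kDePrepCache;
  std::cout << saw_cache << '\n';      // prints 1
  // With the old order (clear first, then post-action) the post-action would see 0
  // and silently skip the cache handling.
  return 0;
}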
@@ -37,7 +37,8 @@ class ExecutionTree {
   // Prepare flags used during tree prepare phase
   enum PrepareFlags {
     kDePrepNone = 0,
-    kDePrepRepeat = 1  // Processing a repeat operation
+    kDePrepRepeat = 1,  // Processing a repeat operation
+    kDePrepCache = 2    // Processing a cache operation
   };

   // State flags for the lifecycle of the tree

@@ -118,9 +119,9 @@ class ExecutionTree {
   // @return Status - The error code return
   Status Launch();

-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  void Print(std::ostream &out) const;
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  void Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op = nullptr) const;

   // Returns an iterator positioned at the start
   // @return Iterator - The iterator

@@ -199,14 +200,23 @@ class ExecutionTree {
   // @return Status - The error code return
   Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);

-  // Adds an operator to the repeat stack during prepare phase.
-  // @param op - The dataset op to add to the repeat stack
-  // @return Status - The error code return
-  void AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op);
+  /// Adds an operator to the eoe operator stack during prepare phase.
+  /// \param op - The dataset op to add to the eoe stack
+  /// \return Status - The error code return
+  void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);

-  // Pops an operator from the repeat stack during prepare phase.
-  // @return shared_ptr to the popped operator
-  std::shared_ptr<DatasetOp> PopFromRepeatStack();
+  /// Pops an operator from the eoe operator stack during prepare phase.
+  /// \return shared_ptr to the popped operator
+  std::shared_ptr<DatasetOp> PopFromEOEOpStack();
+
+  /// Adds a sampler to the sampler stack during prepare phase.
+  /// \param sampler - The sampler to add to the sampler stack
+  /// \return Status - The error code return
+  void AddToSamplerStack(std::shared_ptr<Sampler> sampler);
+
+  /// Pops a sampler from the sampler stack during prepare phase.
+  /// \return shared_ptr to the popped sampler
+  std::shared_ptr<Sampler> PopFromSamplerStack();

   // Return the pointer to the TaskGroup
   // @return raw pointer to the TaskGroup

@@ -236,9 +246,10 @@ class ExecutionTree {
   int32_t id_count_;                                     // Counter for generating operator id's
   uint32_t prepare_flags_;                               // Flags used during tree prepare
   TreeState tree_state_;                                 // Tracking the current tree state
-  std::stack<std::shared_ptr<DatasetOp>> repeat_stack_;  // A stack used during prepare phase
   std::unique_ptr<Monitor> perf_monitor_;                // Performance Monitor
   std::unique_ptr<ProfilingManager> profiling_manager_;  // Profiling manager
+  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;     // A stack used during prepare phase
+  std::stack<std::shared_ptr<Sampler>> sampler_stack_;   // A stack used during prepare phase
 };
 } // namespace dataset
 } // namespace mindspore