Remove Repeat(1),Take(-1), and Skip(0) in NodeRemovalPass

This commit is contained in:
Nat Sutyanyong 2020-12-03 20:51:02 -05:00
parent 3280474d71
commit 4cb78f2e03
80 changed files with 527 additions and 426 deletions

View File

@ -490,12 +490,6 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s
#endif
RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
// Workaround for repeat == 1, do not inject repeat.
if (count == 1) {
ir_node_ = input->IRNode();
return;
}
auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
@ -516,13 +510,6 @@ SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
}
TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
// If count is greater than the number of element in dataset or equal to -1,
// all the element in dataset will be taken
if (count == -1) {
ir_node_ = input->IRNode();
return;
}
auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);

View File

@ -66,7 +66,7 @@ class ColDescriptor {
/// an unknown dimension, then the output shape returned shall resolve dimensions as needed.
/// \param[in] num_elements - The number of elements in the data for a Tensor
/// \param[inout] out_shape - The materialized output Tensor shape
/// \return Status - The error code return
/// \return Status The status code returned
Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const;
/// \brief << Stream output operator overload
@ -124,13 +124,13 @@ class DataSchema {
/// \brief Parses a schema json file and populates the columns and meta info.
/// \param[in] schema_file_path - the schema file that has the column's info to load
/// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
/// \return Status - The error code return
/// \return Status The status code returned
Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load);
/// \brief Parses a schema JSON string and populates the columns and meta info.
/// \param[in] schema_json_string - the schema file that has the column's info to load
/// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
/// \return Status - The error code return
/// \return Status The status code returned
Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load);
/// \brief A print method typically used for debugging
@ -148,7 +148,7 @@ class DataSchema {
/// \brief Adds a column descriptor to the schema
/// \param[in] cd - The ColDescriptor to add
/// \return Status - The error code return
/// \return Status The status code returned
Status AddColumn(const ColDescriptor &cd);
/// \brief getter
@ -169,7 +169,7 @@ class DataSchema {
/// \brief Loops through all columns in the schema and returns a map with the column name to column index number.
/// \param[inout] out_column_name_map - The output map of columns names to column index
/// \return Status - The error code return
/// \return Status The status code returned
Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);
private:
@ -177,7 +177,7 @@ class DataSchema {
/// does not follow any particular order (json standard does not enforce any ordering protocol).
/// This one produces a schema that contains all of the columns from the schema file.
/// \param[in] column_tree - The nlohmann tree from the json file to parse
/// \return Status - The error code return
/// \return Status The status code returned
Status AnyOrderLoad(nlohmann::json column_tree);
/// \brief Internal helper function. For each input column name, perform a lookup to the json document to
@ -185,18 +185,18 @@ class DataSchema {
/// descriptor and add to the schema in the order in which the input column names are given.
/// \param[in] column_tree - The nlohmann tree from the json file to parse
/// \param[in] columns_to_load - list of strings for the columns to add to the schema
/// \return Status - The error code return
/// \return Status The status code returned
Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load);
/// \brief Internal helper function. Given the json tree for a given column, load it into our schema.
/// \param[in] columnTree - The nlohmann child tree for a given column to load.
/// \param[in] col_name - The string name of the column for that subtree.
/// \return Status - The error code return
/// \return Status The status code returned
Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name);
/// \brief Internal helper function. Performs sanity checks on the json file setup.
/// \param[in] js - The nlohmann tree for the schema file
/// \return Status - The error code return
/// \return Status The status code returned
Status PreLoadExceptionCheck(const nlohmann::json &js);
std::vector<ColDescriptor> col_descs_; // Vector of column descriptors

View File

@ -53,7 +53,7 @@ class IteratorBase {
// functionality exists in the derived versions of this function.
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @return Status The status code returned
// @note The position of a Tensor/column might be different from the initial column order
// in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change
// the column ordering.
@ -97,17 +97,17 @@ class DatasetIterator : public IteratorBase {
// from the tree root node directly.
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @return Status The status code returned
Status FetchNextTensorRow(TensorRow *out_row) override;
// Fetches the next tensor row into device row, and returns it's shape.
// @param out_shapes - A vector of tensor shapes (one shape per column)
// @return Status - The error code return
// @return Status The status code returned
Status GetOutputShapes(std::vector<TensorShape> *out_shapes);
// Fetches the next tensor row into device row, and returns it's shape.
// @param outShapes - A vector of tensor shapes (one shape per column)
// @return Status - The error code return
// @return Status The status code returned
Status GetOutputTypes(std::vector<DataType> *out_types);
// Getter
@ -140,12 +140,12 @@ class ChildIterator : public IteratorBase {
// only from the child/worker id as given from the constructor.
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
// @return Status - The error code return
// @return Status The status code returned
Status FetchNextTensorRow(TensorRow *out_row) override;
// This function drains buffer until next eoe has been received.
// It will be a no-op if the previous row returned is empty.
// @return Status - The error code return
// @return Status The status code returned
Status Drain();
// Getter

View File

@ -134,7 +134,7 @@ class BarrierOp : public PipelineOp {
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Handles preprocessing of the main loop, used when starting new epoch

View File

@ -112,12 +112,12 @@ class BatchOp : public ParallelOp {
#endif
// @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg
// @return Status - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<BatchOp> *);
private:
// Sanity check for builder class args
// @return Status - The error code return
// @return Status The status code returned
Status SanityCheck();
bool builder_drop_;
@ -167,11 +167,11 @@ class BatchOp : public ParallelOp {
~BatchOp() {}
// @param int32_t workerId
// @return Status - The error code return
// @return Status The status code returned
Status EofReceived(int32_t) override;
// @param int32_t workerId
// @return Status - The error code return
// @return Status The status code returned
Status EoeReceived(int32_t) override;
// A print method typically used for debugging
@ -190,7 +190,7 @@ class BatchOp : public ParallelOp {
}
// Main loop of batch
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Base-class override for NodePass visitor acceptor.
@ -214,14 +214,14 @@ class BatchOp : public ParallelOp {
// @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
// @param int32_t size - batch_size
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
// @return Status The status code returned
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
dsize_t batch_size);
// @param table
// @param const PadInfo &pad_info pad info
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
// @return Status - The error code return
// @return Status The status code returned
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);
@ -233,18 +233,18 @@ class BatchOp : public ParallelOp {
private:
// Worker thread for doing the memcpy of batch
// @param int32_t param workerId
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Generate buffer with batched tensors
// @return Status - The error code return
// @return Status The status code returned
Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
std::unique_ptr<DataBuffer> *db);
#ifdef ENABLE_PYTHON
// Function that calls pyfunc to perform map on batch
// @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
// @return Status - The error code return
// @return Status The status code returned
Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
#endif
@ -253,7 +253,7 @@ class BatchOp : public ParallelOp {
// @param std::set<int32_t> *cols, col ids to perform pad on
// @param std::vector<float> *vals, default padding value for each column
// @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user
// @return Status - The error code return
// @return Status The status code returned
static Status UnpackPadInfo(const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map,
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
@ -264,20 +264,20 @@ class BatchOp : public ParallelOp {
int32_t num_consumers() const override { return 1; }
// get the batch size for next batch
// @return Status - The error code return
// @return Status The status code returned
Status GetBatchSize(int32_t *batch_size, CBatchInfo info);
// Do the initialization of all queues then start all worker threads
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp();
#ifdef ENABLE_PYTHON
// Invoke batch size function with current BatchInfo to generate batch size.
// @return Status - The error code return
// @return Status The status code returned
Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);
// Invoke batch map function with current BatchInfo to generate tensors to batch.
// @return Status - The error code return
// @return Status The status code returned
Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
#endif

View File

@ -107,7 +107,7 @@ class BucketBatchByLengthOp : public PipelineOp {
// Might need to batch remaining buckets after receiving eoe, so override this method.
// @param int32_t workerId
// @return Status - The error code returned
// @return Status The status code returned
Status EoeReceived(int32_t) override;
std::string Name() const override { return kBucketBatchByLengthOp; }
@ -123,7 +123,7 @@ class BucketBatchByLengthOp : public PipelineOp {
}
// Main loop of batch
// @return Status - The error code returned
// @return Status The status code returned
Status operator()() override;
private:

View File

@ -104,7 +104,7 @@ class BuildSentencePieceVocabOp : public PipelineOp {
// The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op);
private:

View File

@ -110,7 +110,7 @@ class BuildVocabOp : public ParallelOp {
// The builder "build" method creates the final object.
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<BuildVocabOp> *op);
private:

View File

@ -53,7 +53,7 @@ class CacheBase : public ParallelOp {
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
/// info from it's previous execution and then initializes itself so that it can be executed
/// again.
/// \return Status - The error code return
/// \return Status The status code returned
Status Reset() override;
/// \brief A print method typically used for debugging

View File

@ -80,7 +80,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
std::shared_ptr<SamplerRT> build_sampler_;
// Check if the required parameters are set by the builder.
// \return Status The error code return
// \return Status The status code returned
Status SanityCheck() const;
};
/// \brief Constructor

View File

@ -136,7 +136,7 @@ class CacheMergeOp : public ParallelOp {
std::shared_ptr<SamplerRT> build_sampler_;
/// Check if the required parameters are set by the builder.
/// \return Status The error code return
/// \return Status The status code returned
Status SanityCheck() const;
};
@ -189,7 +189,7 @@ class CacheMergeOp : public ParallelOp {
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
Status EofReceived(int32_t worker_id) override;
protected:

View File

@ -99,7 +99,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
std::shared_ptr<SamplerRT> build_sampler_;
/// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return
/// \return Status The status code returned
Status SanityCheck() const;
};
@ -119,7 +119,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \brief Base-class override for special eoe handler.
/// CacheOp must override this because it shall not perform default handling of eoe. Instead
/// the CacheOp manages actions related to the end of the epoch.
/// \return Status - The error code return
/// \return Status The status code returned
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
@ -133,7 +133,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
Status EofReceived(int32_t worker_id) override;
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
@ -159,7 +159,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
Status CacheAllRows(int32_t worker_id);
Status RegisterResources() override;
/// \brief Private function for cache setup/init work just after construction
/// \return Status The error code return
/// \return Status The status code returned
Status InitCache();
};
} // namespace dataset

View File

@ -94,7 +94,7 @@ class ConcatOp : public PipelineOp {
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Op name getter

View File

@ -146,14 +146,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// DatasetOps operate by launching a thread (see ExecutionTree).
/// This pure virtual version makes the requirement that derived classes must provide a functor
/// that will execute their main runtime loop code.
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status operator()() = 0;
/// \brief Gets the next buffer from the given child
/// \notes See GetNextInput for similar function that has built-in message handling
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
return GetNextBuffer(p_buffer, worker_id, false);
}
@ -161,7 +161,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \brief Gets the next buffer from the given child
/// \notes See GetNextInput for similar function that has built-in message handling
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); }
/// \brief Gets the next buffer from the given child
@ -169,7 +169,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \param worker_id - The worker id
/// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe);
/// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof
@ -177,7 +177,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// those messages are received.
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
/// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
/// \brief Gets the batch size
@ -200,19 +200,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// The base class implementation simply flows the eoe message to output. Derived classes
/// may override if they need to perform special eoe handling.
/// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status EoeReceived(int32_t worker_id);
/// \brief Performs handling for when an eof message is received.
/// The base class implementation simply flows the eof message to output. Derived classes
/// may override if they need to perform special eof handling.
/// \param worker_id - The worker id
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status EofReceived(int32_t worker_id);
/// \brief Derived classes may implement the reset function if the operator is stateful and needs
/// specific reset handling that is not contained in this common code version of the reset
/// \return Status - The error code return
/// \return Status The status code returned
virtual Status Reset();
/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on

View File

@ -79,7 +79,7 @@ class FilterOp : public ParallelOp {
private:
// Sanity check for builder class args.
// @return Status - The error code return.
// @return Status The status code returned.
Status SanityCheck();
std::vector<std::string> build_in_col_names_;
std::shared_ptr<TensorOp> builder_predicate_func_;
@ -105,15 +105,15 @@ class FilterOp : public ParallelOp {
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree),This class functor will
// provide the master loop that drives the logic for performing the work.
// @return Status The error code return
// @return Status The status code returned
Status operator()() override;
// @param int32_t workerId.
// @return Status - The error code return.
// @return Status The status code returned.
Status EofReceived(int32_t) override;
// @param int32_t workerId.
// @return Status - The error code return.
// @return Status The status code returned.
Status EoeReceived(int32_t) override;
// A print method typically used for debugging.
@ -151,34 +151,34 @@ class FilterOp : public ParallelOp {
// logic of FilterOp, getting the data from previous Op, validating user specified column names,
// applying predicate to each of the data, filter the data when predicate result is false.
// @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return.
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_
// Filter the data by predicate function .
// @param in_buffer input data buffer.
// @param to_proess_indices Indices of columns to be processed.
// @param out data buffer that are filtered by predicate.
// @return Status The error code return.
// @return Status The status code returned
Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);
// Collector databuffer.
// @return Status The error code return.
// @return Status The status code returned
Status Collector();
// @param input tensor vector.
// @return Status - The error code return.
// @return Status The status code returned.
Status CheckInput(const TensorRow &input) const;
// Invoke python func.
// @param input tensor vector.
// @param the result of predicate.
// @return Status - The error code return.
// @return Status The status code returned.
Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);
// Private function for validating if each of the user specified input column names
// exist in the DataBuffer.
// @param input_columns The vector of input column names used in the current thread.
// @return Status The error code return.
// @return Status The status code returned
Status ValidateInColumns(const std::vector<std::string> *input_columns);
// Private function for checking the column legality

View File

@ -133,7 +133,7 @@ class MapOp : public ParallelOp {
int32_t build_op_connector_size_;
// Check if the required parameters are set by the builder.
// @return Status The error code return
// @return Status The status code returned
Status sanityCheck() const;
};
@ -170,7 +170,7 @@ class MapOp : public ParallelOp {
// provide the master loop that drives the logic for performing the work
// This main thread creates local queues, pulls databuffers from the previous
// op's Connector and distributes them to the local queues. Workers pull from the local queues.
// @return Status The error code return
// @return Status The status code returned
Status operator()() override;
// Getter
@ -239,7 +239,7 @@ class MapOp : public ParallelOp {
// applying a list of TensorOps to each of the data, process the results and then
// pushing them back to MapOp's output Connector to be fetched by the next Op.
// @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_
// Private function for worker thread to perform TensorOp's compute function and get the result.

View File

@ -89,7 +89,7 @@ class ParallelOp : public DatasetOp {
}
// Override base class reset to provide reset actions specific to the ParallelOp class.
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Getter
@ -115,7 +115,7 @@ class ParallelOp : public DatasetOp {
protected:
// Interface for derived classes to implement. All derived classes must provide the entry
// function with the main execution loop for worker threads.
// @return Status - The error code return
// @return Status The status code returned
virtual Status WorkerEntry(int32_t workerId) = 0;
/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp

View File

@ -75,7 +75,7 @@ class ProjectOp : public PipelineOp {
// However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code returned.
// @return Status The status code returned
Status operator()() override;
// Gets a buffer from the child node and projects that buffer. The caller is typically our parent node.
@ -93,12 +93,12 @@ class ProjectOp : public PipelineOp {
// Base-class override for special eoe handler.
// Inline operators must override this because there is no connector to push eoe onto.
// @return Status - The error code returned.
// @return Status The status code returned
Status EoeReceived(int32_t worker_id) override;
// Base-class override for special eof handler.
// Inline operators must override this because there is no connector to push eof onto.
// @return Status - The error code returned.
// @return Status The status code returned
Status EofReceived(int32_t worker_id) override;
// Base-class override for NodePass visitor acceptor.

View File

@ -107,7 +107,7 @@ class RenameOp : public PipelineOp {
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Base-class override for NodePass visitor acceptor.

View File

@ -78,7 +78,7 @@ class RepeatOp : public PipelineOp {
// However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the
// functor since this op runs inlined inside another operator. The function is overloaded to
// ensure that it is not called by mistake (it will generate an error).
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// This function returns the buffer that is at the top of our output connector. The caller is
@ -90,7 +90,7 @@ class RepeatOp : public PipelineOp {
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
// @return Status - The error code return
// @return Status The status code returned
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
// Base-class override for handling cases when an eoe is received.
@ -130,7 +130,7 @@ class RepeatOp : public PipelineOp {
int32_t num_repeats() { return num_repeats_; }
/// \brief reset Op
/// \@return Status - The error code return
/// \@return Status The status code returned
Status Reset() override;
int64_t GetTreeRepeatCount() override;

View File

@ -146,13 +146,13 @@ class ShuffleOp : public PipelineOp {
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Base-class override for special eoe handler.
// ShuffleOp must override this because it shall not perform default handling of eoe. Instead
// the ShuffleOp needs to manage actions related to the end of the epoch itself.
// @return Status - The error code return
// @return Status The status code returned
Status EoeReceived(int32_t worker_id) override;
// Base-class override for NodePass visitor acceptor.
@ -167,17 +167,17 @@ class ShuffleOp : public PipelineOp {
private:
// Private function to add a new row to the shuffle buffer.
// @return Status - The error code return
// @return Status The status code returned
Status AddRowToShuffleBuffer(TensorRow new_shuffle_row);
// Private function to populate the shuffle buffer initially by fetching from the child output
// connector until the shuffle buffer is full (or there is no more data coming).
// @return Status - The error code return
// @return Status The status code returned
Status InitShuffleBuffer();
// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
// itself rather than waiting for the reset driven from operators above it in the pipeline.
// @return Status - The error code return
// @return Status The status code returned
Status SelfReset();
int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows)

View File

@ -63,7 +63,7 @@ class SkipOp : public PipelineOp {
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Base-class override for handling cases when an eoe is received.

View File

@ -130,12 +130,12 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
}
/// \brief Check validity of input args
/// \return - The error code returned
/// \return Status The status code returned
Status SanityCheck();
/// \brief The builder "build" method creates the final object.
/// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp
/// \return - The error code returned
/// \return Status The status code returned
Status Build(std::shared_ptr<AlbumOp> *op);
private:
@ -168,18 +168,18 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
~AlbumOp() = default;
/// \brief Initialize AlbumOp related var, calls the function to walk all files
/// \return - The error code returned
/// \return Status The status code returned
Status PrescanEntry();
/// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
/// \param[in] int32_t workerId - id of each worker
/// \return Status - The error code returned
/// \return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
/// \brief Main Loop of AlbumOp
/// Master thread: Fill IOBlockQueue, then goes to sleep
/// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
/// \return Status - The error code returned
/// \return Status The status code returned
Status operator()() override;
/// \brief A print method typically used for debugging
@ -204,93 +204,93 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
private:
/// \brief Initialize Sampler, calls sampler->Init() within
/// \return Status The error code returned
/// \return Status The status code returned
Status InitSampler();
/// \brief Load image to tensor row
/// \param[in] image_file Image name of file
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row);
/// \brief Load vector of ints to tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing multi-dimensional label
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load vector of floatss to tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing array data
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load string array into a tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing string tensor
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load string into a tensor, append tensor to tensor row
/// \param[in] json_obj Json object containing string tensor
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load float value to tensor row
/// \param[in] json_obj Json object containing float
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load int value to tensor row
/// \param[in] json_obj Json object containing int
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
/// \brief Load emtpy tensor to tensor row
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadEmptyTensor(uint32_t col_num, TensorRow *row);
/// \brief Load id from file name to tensor row
/// \param[in] file The file name to get ID from
/// \param[in] col_num Column num in schema
/// \param[inout] row Tensor row to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row);
/// \brief Load a tensor row according to a json file
/// \param[in] row_id_type row_id - id for this tensor row
/// \param[in] ImageColumns file Json file location
/// \param[inout] TensorRow row Json content stored into a tensor row
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row);
/// \param[in] const std::vector<int64_t> &keys Keys in ioblock
/// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to
/// \return Status The error code returned
/// \return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
/// \brief Called first when function is called
/// \return Status The error code returned
/// \return Status The status code returned
Status LaunchThreadsAndInitOp();
/// \brief reset Op
/// \return Status The error code return
/// \return Status The status code returned
Status Reset() override;
// Private function for computing the assignment of the column name map.
// @return Status The error code returned
// @return Status The status code returned
Status ComputeColMap() override;
int32_t rows_per_buffer_;

View File

@ -116,12 +116,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
return *this;
}
// Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "build" method creates the final object.
// @param std::shared_ptr<CelebAOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<CelebAOp> *op);
private:
@ -151,12 +151,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// Main Loop of CelebAOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// A print method typically used for debugging
@ -166,7 +166,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// Method in operator(), to fill IOBlockQueue
// @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue
// @return Status - The error code return
// @return Status The status code returned
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
/// \brief Base-class override for NodePass visitor acceptor
@ -199,14 +199,14 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param std::pair - <image_file,<label>>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label,
TensorRow *row);
@ -215,7 +215,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
bool CheckDatasetTypeValid();
// reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Private function for computing the assignment of the column name map.

View File

@ -109,12 +109,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
}
// Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "build" method creates the final object.
// @param std::shared_ptr<CifarOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<CifarOp> *op);
private:
@ -144,13 +144,13 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param uint32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Main Loop of CifarOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// A print method typically used for debugging
@ -177,18 +177,18 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler();
// Load a tensor row according to a pair
// @param uint64_t index - index need to load
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(uint64_t index, TensorRow *row);
// @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Read block data from cifar file
@ -200,7 +200,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
Status LaunchThreadsAndInitOp();
// reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Get cifar files in dir
@ -221,7 +221,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// Method derived from RandomAccess Op, enable Sampler to get all ids for each calss
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
// Private function for computing the assignment of the column name map.

View File

@ -133,12 +133,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
}
// Check validity of input args
// @return = The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "Build" method creates the final object.
// @param std::shared_ptr<CocoOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<CocoOp> *op);
private:
@ -173,13 +173,13 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Main Loop of CocoOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// A print method typically used for debugging
@ -214,19 +214,19 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
std::string Name() const override { return "CocoOp"; }
/// \brief Gets the class indexing
/// \return Status - The status code return
/// \return Status The status code returned
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler();
// Load a tensor row according to image id
// @param row_id_type row_id - id for this tensor row
// @param std::string image_id - image id
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
// Load a tensor row with vector which a vector to a tensor
@ -235,7 +235,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
std::shared_ptr<Tensor> coordinate, TensorRow *trow);
@ -245,7 +245,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
std::shared_ptr<Tensor> coordinate, TensorRow *trow);
@ -255,69 +255,69 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
std::shared_ptr<Tensor> coordinate, TensorRow *trow);
// @param const std::string &path - path to the image file
// @param const ColDescriptor &col - contains tensor implementation and datatype
// @param std::shared_ptr<Tensor> tensor - return
// @return Status - The error code return
// @return Status The status code returned
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
// @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Read annotation from Annotation folder
// @return Status - The error code return
// @return Status The status code returned
Status ParseAnnotationIds();
// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
// @param std::vector<int64_t> *keys - image id
// @return Status - The error code return
// @return Status The status code returned
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
// Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp();
// Reset dataset state
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// @param nlohmann::json image_tree - image tree of json
// @param std::vector<std::string> *image_vec - image id list of json
// @return Status - The error code return
// @return Status The status code returned
Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec);
// @param nlohmann::json categories_tree - categories tree of json
// return Status - The error code return
// @return Status The status code returned
Status CategoriesColumnLoad(const nlohmann::json &categories_tree);
// @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation
// @param const int32_t &id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
// @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation
// @param const int32_t &id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
// @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation
// @param const int32_t &id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
// @param nlohmann::json categories_tree - categories tree of json
// @param const std::string &image_file - current image name in annotation
// @param const int32_t &image_id - current unique id of annotation
// @return Status - The error code return
// @return Status The status code returned
Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file,
const int32_t &image_id);

View File

@ -115,13 +115,13 @@ class GeneratorOp : public PipelineOp {
// Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed
// again.
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Base-class override for NodePass visitor acceptor.

View File

@ -135,12 +135,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
}
// Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "build" method creates the final object.
// @param std::shared_ptr<ImageFolderOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<ImageFolderOp> *op);
private:
@ -172,28 +172,28 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
// Initialize ImageFOlderOp related var, calls the function to walk all files
// @param - std::string dir file directory to ImageNetFolder
// @return - The error code return
// @return Status The status code returned
Status PrescanMasterEntry(const std::string &dir);
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status PrescanWorkerEntry(int32_t worker_id);
// Main Loop of ImageFolderOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
// A print method typically used for debugging
@ -224,19 +224,19 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler();
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row);
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// @param std::string & dir - dir to walk all images
@ -253,7 +253,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
Status LaunchThreadsAndInitOp();
// reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Private function for computing the assignment of the column name map.

View File

@ -58,12 +58,12 @@ class IOBlock {
// Fetches the first key from the block.
// @note Only useful if you know the block only has 1 key.
// @return A copy of the first key from the block
// @return Status - The error code return
// @return Status The status code returned
Status GetKey(int64_t *out_key) const;
// Fetches the list of keys from this block.
// @param out_keys - A copy of the vector of keys from the block.
// @return Status - The error code return
// @return Status The status code returned
Status GetKeys(std::vector<int64_t> *out_keys) const;
// Does this block have the eoe flag turned on?
@ -110,7 +110,7 @@ class FilenameBlock : public IOBlock {
// Gets the filename from the block using the provided index container
// @param out_filename - The filename to add to the block
// @param index - The index to perform lookup against
// @return Status - The error code return
// @return Status The status code returned
Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const;
// Get the start offset of file

View File

@ -110,12 +110,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
}
// Check validity of input args
// @return Status - The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "build" method creates the final object.
// @param std::shared_ptr<ManifestOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<ManifestOp> *op);
private:
@ -145,18 +145,18 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Main Loop of ManifestOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
// A print method typically used for debugging
@ -201,37 +201,37 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler();
// Method in operator(), to fill IOBlockQueue
// @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue
// @return Status - The error code return
// @return Status The status code returned
Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer);
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data,
TensorRow *row);
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Parse manifest file to get image path and label and so on.
// @return Status - The error code return
// @return Status The status code returned
Status ParseManifestFile();
// Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp();
// reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Check if image ia valid.Only support JPEG/PNG/GIF/BMP
@ -239,7 +239,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
Status CheckImageType(const std::string &file_name, bool *valid);
// Count label index,num rows and num samples
// @return Status - The error code return
// @return Status The status code returned
Status CountDatasetInfo();
// Private function for computing the assignment of the column name map.

View File

@ -164,13 +164,13 @@ class MindRecordOp : public ParallelOp {
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Called first when function is called
@ -180,7 +180,7 @@ class MindRecordOp : public ParallelOp {
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from it's previous execution and then initializes itself so that it can be executed
// again.
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Getter method

View File

@ -99,12 +99,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
return *this;
}
// Check validity of input args
// @return - The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "Build" method creates the final object.
// @param std::shared_ptr<MnistOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<MnistOp> *op);
private:
@ -133,18 +133,18 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t worker_id - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Main Loop of MnistOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
// @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class
// @return Status - The error code return
// @return Status The status code returned
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
// A print method typically used for debugging
@ -170,39 +170,39 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler();
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row);
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Iterate through all members in sampleIds and fill them into IOBlock.
// @param std::shared_ptr<Tensor> sample_ids -
// @param std::vector<int64_t> *keys - keys in ioblock
// @return Status - The error code return
// @return Status The status code returned
Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
// Check image file stream.
// @param const std::string *file_name - image file name
// @param std::ifstream *image_reader - image file stream
// @param uint32_t num_images - returns the number of images
// @return Status - The error code return
// @return Status The status code returned
Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);
// Check label stream.
// @param const std::string &file_name - label file name
// @param std::ifstream *label_reader - label file stream
// @param uint32_t num_labels - returns the number of labels
// @return Status - The error code return
// @return Status The status code returned
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);
// Read 4 bytes of data from a file stream.
@ -219,23 +219,23 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @param std::ifstream *image_reader - image file stream
// @param std::ifstream *label_reader - label file stream
// @param int64_t read_num - number of image to read
// @return Status - The error code return
// @return Status The status code returned
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);
// Parse all mnist dataset files
// @return Status - The error code return
// @return Status The status code returned
Status ParseMnistData();
// Read all files in the directory
// @return Status - The error code return
// @return Status The status code returned
Status WalkAllFiles();
// Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp();
// reset Op
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Private function for computing the assignment of the column name map.

View File

@ -63,7 +63,7 @@ class RandomDataOp : public ParallelOp {
/**
* The build method that produces the instantiated RandomDataOp as a shared pointer
* @param out_op - The output RandomDataOperator that was constructed
* @return Status - The error code return
* @return Status The status code returned
*/
Status Build(std::shared_ptr<RandomDataOp> *out_op);
@ -128,7 +128,7 @@ class RandomDataOp : public ParallelOp {
private:
/**
* Check if the required parameters are set by the builder.
* @return Status - The error code return
* @return Status The status code returned
*/
Status SanityCheck() const;
@ -182,7 +182,7 @@ class RandomDataOp : public ParallelOp {
* Class functor operator () override.
* All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
* provide the master loop that drives the logic for performing the work.
* @return Status - The error code return
* @return Status The status code returned
*/
Status operator()() override;
@ -190,7 +190,7 @@ class RandomDataOp : public ParallelOp {
* Overrides base class reset method. When an operator does a reset, it cleans up any state
* info from it's previous execution and then initializes itself so that it can be executed
* again.
* @return Status - The error code return
* @return Status The status code returned
*/
Status Reset() override;
@ -207,7 +207,7 @@ class RandomDataOp : public ParallelOp {
/**
* The entry point code for when workers are launched
* @param worker_id - The worker id
* @return Status - The error code return
* @return Status The status code returned
*/
Status WorkerEntry(int32_t worker_id) override;
@ -219,7 +219,7 @@ class RandomDataOp : public ParallelOp {
/**
* Performs a synchronization between workers at the end of an epoch
* @param worker_id - The worker id
* @return Status - The error code return
* @return Status The status code returned
*/
Status EpochSync(int32_t worker_id, bool *quitting);
@ -227,7 +227,7 @@ class RandomDataOp : public ParallelOp {
* A helper function to stuff the tensor table into a buffer and send it to output connector
* @param worker_id - The worker id
* @param in_table - The tensor table to pack and send
* @return Status - The error code return
* @return Status The status code returned
*/
Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table);
@ -235,7 +235,7 @@ class RandomDataOp : public ParallelOp {
* A helper function to create random data for the row
* @param worker_id - The worker id
* @param new_row - The output row to produce
* @return Status - The error code return
* @return Status The status code returned
*/
Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);

View File

@ -40,7 +40,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
// @param std::unique_ptr<DataBuffer pBuffer
// @param int32_t workerId
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
@ -53,7 +53,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
Status InitSampler() override;
// for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override;
// Printer for debugging purposes.

View File

@ -41,13 +41,13 @@ class PythonSamplerRT : public SamplerRT {
Status InitSampler() override;
// for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override;
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// Printer for debugging purposes.

View File

@ -40,14 +40,14 @@ class RandomSamplerRT : public SamplerRT {
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// meant to be called by base class or python
Status InitSampler() override;
// for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override;
void SamplerPrint(std::ostream &out, bool show_all) const override;

View File

@ -35,12 +35,12 @@ class RandomAccessOp {
public:
// Sampler get number of rows in the dataset
// @param int64_t num - return number of rows for this dataset
// @return - The error code return
// @return Status The status code returned
Status GetNumRowsInDataset(int64_t *num_rows) const;
// sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK
// @param std::map<int64_t, std::vector<int64_t>> * map
// @return - The error code return
// @return Status The status code returned
virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK");
}
@ -71,7 +71,7 @@ class SamplerRT {
// @note It is Sampler responsibility to make sure that the id is not out of bound.
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;
// This function only called by python layer. Not needed by Android.
@ -81,7 +81,7 @@ class SamplerRT {
#endif
// for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
virtual Status ResetSampler() = 0;
// first handshake between leaf source op and Sampler. This func will determine the amount of data
@ -114,13 +114,13 @@ class SamplerRT {
// Adds a sampler to become our child.
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
// @return - The error code returned.
// @return Status The status code returned
Status AddChild(std::shared_ptr<SamplerRT> child);
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
// @param std::shared_ptr<Tensor>* sampleIds
// @param int64_t numElements - must be a non 0 number
// @return - The error code returned.
// @return Status The status code returned
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
// A print method typically used for debugging
@ -146,7 +146,7 @@ class SamplerRT {
// associated id.
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
// @param int64_t id - The id used as an index to get the associated child id.
// @return - The error code returned.
// @return Status The status code returned
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
protected:

View File

@ -40,13 +40,13 @@ class SequentialSamplerRT : public SamplerRT {
Status InitSampler() override;
// for next epoch of sampleIds
// @return - The error code return
// @return Status The status code returned
Status ResetSampler() override;
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
// @param int32_t workerId - not meant to be used
// @return - The error code return
// @return Status The status code returned
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
// Printer for debugging purposes.

View File

@ -132,12 +132,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
}
// Check validity of input args
// @return = The error code return
// @return Status The status code returned
Status SanityCheck();
// The builder "Build" method creates the final object.
// @param std::shared_ptr<VOCOp> *op - DatasetOp
// @return - The error code return
// @return Status The status code returned
Status Build(std::shared_ptr<VOCOp> *op);
private:
@ -173,13 +173,13 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
// @param int32_t workerId - id of each worker
// @return Status - The error code return
// @return Status The status code returned
Status WorkerEntry(int32_t worker_id) override;
// Main Loop of VOCOp
// Master thread: Fill IOBlockQueue, then goes to sleep
// Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// A print method typically used for debugging
@ -222,55 +222,55 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
// @return Status The status code returned
Status InitSampler();
// Load a tensor row according to image id
// @param row_id_type row_id - id for this tensor row
// @param std::string image_id - image id
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
// @return Status The status code returned
Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
// @param const std::string &path - path to the image file
// @param const ColDescriptor &col - contains tensor implementation and datatype
// @param std::shared_ptr<Tensor> tensor - return
// @return Status - The error code return
// @return Status The status code returned
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
// @param const std::string &path - path to the image file
// @param TensorRow *row - return
// @return Status - The error code return
// @return Status The status code returned
Status ReadAnnotationToTensor(const std::string &path, TensorRow *row);
// @param const std::vector<uint64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
// @return Status - The error code return
// @return Status The status code returned
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
// Read image list from ImageSets
// @return Status - The error code return
// @return Status The status code returned
Status ParseImageIds();
// Read annotation from Annotation folder
// @return Status - The error code return
// @return Status The status code returned
Status ParseAnnotationIds();
// @param const std::string &path - path to annotation xml
// @return Status - The error code return
// @return Status The status code returned
Status ParseAnnotationBbox(const std::string &path);
// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
// @param std::vector<int64_t> *keys - image id
// @return Status - The error code return
// @return Status The status code returned
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
// Called first when function is called
// @return Status - The error code return
// @return Status The status code returned
Status LaunchThreadsAndInitOp();
// Reset dataset state
// @return Status - The error code return
// @return Status The status code returned
Status Reset() override;
// Private function for computing the assignment of the column name map.

View File

@ -75,7 +75,7 @@ class TakeOp : public PipelineOp {
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
// Base-class override for NodePass visitor acceptor.

View File

@ -101,7 +101,7 @@ class ZipOp : public PipelineOp {
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
// @return Status The status code returned
Status operator()() override;
/// \brief Base-class override for NodePass pre-visit acceptor

View File

@ -226,7 +226,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
// @return Status The status code returned
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
num_epochs_ = num_epochs;
partially_prepare_ = partial;

View File

@ -115,16 +115,16 @@ class ExecutionTree {
// provides it with a link to the tree. A node cannot form any relationships (parent/child) with
// other nodes unless they are associated with the same tree.
// @param op - The operator to associate
// @return Status - The error code return
// @return Status The status code returned
Status AssociateNode(const std::shared_ptr<DatasetOp> &op);
// Sets the root node of the tree
// @param op - The operator to assign as root
// @return Status - The error code return
// @return Status The status code returned
Status AssignRoot(const std::shared_ptr<DatasetOp> &op);
// Start the execution of the tree
// @return Status - The error code return
// @return Status The status code returned
Status Launch();
/// A print method typically used for debugging
@ -155,7 +155,7 @@ class ExecutionTree {
// wrapper for the TaskGroup handling that is stored inside the execution tree.
// @param num_workers - The number of workers to launch
// @param func - The function entry point that workers will execute
// @return Status - The error code return
// @return Status The status code returned
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");
// Getter method
@ -181,32 +181,32 @@ class ExecutionTree {
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
// @return Status The status code returned
Status Prepare(int num_epochs = -1, bool partial = false);
// Compulsory transformation/action pre optimization.
// @return Status - The error code return
// @return Status The status code returned
Status PreAction();
// Compulsory transformation/action post optimization.
// @return Status - The error code return
// @return Status The status code returned
Status PostAction();
// Optimization transformation/action, optional.
// @return Status - The error code return
// @return Status The status code returned
Status Optimize();
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
// @param Total number of epochs that will be run on this tree
// @return Status - The error code return
// @return Status The status code returned
Status PrepareDeprecated();
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
// @param op - The dataset op to work on
// @return Status - The error code return
// @return Status The status code returned
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
// Return the pointer to the TaskGroup

View File

@ -51,7 +51,7 @@ class Edge {
// Get the feature of a edge
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
// Get nodes on the edge
@ -71,7 +71,7 @@ class Edge {
// Update feature of edge
// @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
protected:

View File

@ -47,19 +47,19 @@ class GraphData {
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
// All neighbors of the acquisition node.
@ -68,7 +68,7 @@ class GraphData {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) = 0;
@ -77,7 +77,7 @@ class GraphData {
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
@ -87,7 +87,7 @@ class GraphData {
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;
@ -98,7 +98,7 @@ class GraphData {
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
// @return Status The status code returned
virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) = 0;
@ -108,7 +108,7 @@ class GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) = 0;
@ -117,7 +117,7 @@ class GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) = 0;

View File

@ -57,19 +57,19 @@ class GraphDataClient : public GraphData {
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
// @return Status The status code returned
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
// @return Status The status code returned
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
@ -78,7 +78,7 @@ class GraphDataClient : public GraphData {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override;
@ -87,7 +87,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
@ -96,7 +96,7 @@ class GraphDataClient : public GraphData {
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
@ -107,7 +107,7 @@ class GraphDataClient : public GraphData {
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
// @return Status The status code returned
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override;
@ -117,7 +117,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
@ -126,7 +126,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;

View File

@ -51,19 +51,19 @@ class GraphDataImpl : public GraphData {
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
// @return Status The status code returned
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
// @return Status The status code returned
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
@ -72,7 +72,7 @@ class GraphDataImpl : public GraphData {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override;
@ -81,7 +81,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
@ -90,7 +90,7 @@ class GraphDataImpl : public GraphData {
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
// @return Status The status code returned
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
@ -101,7 +101,7 @@ class GraphDataImpl : public GraphData {
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
// @return Status The status code returned
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override;
@ -111,7 +111,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
@ -123,7 +123,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param Tensor *out - Returned features
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
@ -132,7 +132,7 @@ class GraphDataImpl : public GraphData {
// Get meta information of graph
// @param MetaInfo *meta_info - Returned meta information
// @return Status - The error code return
// @return Status The status code returned
Status GetMetaInfo(MetaInfo *meta_info);
#ifdef ENABLE_PYTHON
@ -202,14 +202,14 @@ class GraphDataImpl : public GraphData {
};
// Load graph data from mindrecord file
// @return Status - The error code return
// @return Status The status code returned
Status LoadNodeAndEdge();
// Create Tensor By Vector
// @param std::vector<std::vector<T>> &data -
// @param DataType type -
// @param std::shared_ptr<Tensor> *out -
// @return Status - The error code return
// @return Status The status code returned
template <typename T>
Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out);
@ -217,32 +217,32 @@ class GraphDataImpl : public GraphData {
// @param std::vector<std::vector<T>> *data - To be completed vector
// @param size_t max_size - The size of the completed vector
// @param T default_value - Filled default
// @return Status - The error code return
// @return Status The status code returned
template <typename T>
Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value);
// Get the default feature of a node
// @param FeatureType feature_type -
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
// Get the default feature of a edge
// @param FeatureType feature_type -
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
// Find node object using node id
// @param NodeIdType id -
// @param std::shared_ptr<Node> *node - Returned node object
// @return Status - The error code return
// @return Status The status code returned
Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
// Find edge object using edge id
// @param EdgeIdType id -
// @param std::shared_ptr<Node> *edge - Returned edge object
// @return Status - The error code return
// @return Status The status code returned
Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge);
// Negative sampling
@ -250,7 +250,7 @@ class GraphDataImpl : public GraphData {
// @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
// @param int32_t samples_num -
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
// @return Status - The error code return
// @return Status The status code returned
Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
std::vector<NodeIdType> *out_samples);

View File

@ -43,12 +43,12 @@ class LocalEdge : public Edge {
// Get the feature of a edge
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
// Update feature of edge
// @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
private:

View File

@ -40,13 +40,13 @@ class LocalNode : public Node {
// Get the feature of a node
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
// Get the all neighbors of a node
// @param NodeType neighbor_type - type of neighbor
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
bool exclude_itself = false) override;
@ -54,18 +54,18 @@ class LocalNode : public Node {
// @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) override;
// Add neighbor of node
// @param std::shared_ptr<Node> node -
// @return Status - The error code return
// @return Status The status code returned
Status AddNeighbor(const std::shared_ptr<Node> &node) override;
// Update feature of node
// @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
private:

View File

@ -49,13 +49,13 @@ class Node {
// Get the feature of a node
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
// Get the all neighbors of a node
// @param NodeType neighbor_type - type of neighbor
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
bool exclude_itself = false) = 0;
@ -63,18 +63,18 @@ class Node {
// @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status - The error code return
// @return Status The status code returned
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
std::vector<NodeIdType> *out_neighbors) = 0;
// Add neighbor of node
// @param std::shared_ptr<Node> node -
// @return Status - The error code return
// @return Status The status code returned
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0;
// Update feature of node
// @param std::shared_ptr<Feature> feature -
// @return Status - The error code return
// @return Status The status code returned
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
protected:

View File

@ -57,6 +57,10 @@ class RepeatNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Getter
/// \return Number of cycles to repeat the execution
const int32_t Count() const { return repeat_count_; }
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting

View File

@ -55,6 +55,10 @@ class SkipNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Getter
/// \return Number of rows to skip
const int32_t Count() const { return skip_count_; }
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting

View File

@ -55,6 +55,10 @@ class TakeNode : public DatasetNode {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Getter
/// \return Number of rows to output
const int32_t Count() const { return take_count_; }
/// \brief Base-class override for GetDatasetSize
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting

View File

@ -29,7 +29,7 @@ class TensorOpFusionPass : public NodePass {
/// \brief Identifies and fuses tensor ops within MapOp
/// \param[in] node The node being visited
/// \param[inout] *modified indicates whether the node has been visited
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
};
} // namespace dataset

View File

@ -135,7 +135,7 @@ class IRTreePass : public IRPass {
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate if the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) { return Status::OK(); }
};
@ -170,14 +170,14 @@ class IRNodePass : public IRPass {
/// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }
/// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
/// "modified" flag needs to be set to true if node is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) { return Status::OK(); }
// Visit()/VisitAfter() method to be overridden.
@ -266,7 +266,7 @@ class TreePass : public Pass {
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
@ -301,14 +301,14 @@ class NodePass : public Pass {
/// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
/// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation
/// "modified" flag needs to be set to true if tree is modified during the pass execution
/// \param[in] node The node being visited
/// \param[out] modified Indicator if the node was changed at all.
/// \return Status The error code return
/// \return Status The status code returned
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
// Visit methods to be overridden.

View File

@ -41,80 +41,80 @@ class RepeatPass : public NodePass {
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
/// \brief CacheOp removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Turns of the tracking for operations under merge op
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
/// \brief Saves the lookup up in case it needs to be referenced by a repeat
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
/// \brief Set the epoch count for DeviceQueue
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
/// \brief Special case for GeneratorOp
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
private:
/// \brief Adds an operator to the eoe operator stack save area
/// \param op - The dataset op to work add to eoe stack
/// \return Status - The error code return
/// \return Status The status code returned
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
/// \brief Pops an operator from the eoe operator stack save area
@ -127,7 +127,7 @@ class RepeatPass : public NodePass {
/// \brief Adds an operator to the cached operator stack save area
/// \param op - The dataset op to work add to cached stack
/// \return Status - The error code return
/// \return Status The status code returned
void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op);
/// \brief Pops an operator from the cached operator stack save area

View File

@ -38,123 +38,123 @@ class CacheErrorPass : public NodePass {
/// \brief Identifies the subtree below this node as being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Returns an error if ZipOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
/// \brief Returns an error if ConcatOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
/// \brief Returns an error if TakeOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
/// \brief Returns an error if SkipOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
/// \brief Returns an error if SkipOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;
#ifdef ENABLE_PYTHON
/// \brief Returns an error if FilterOp exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
#endif
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
/// \brief Identifies the leaf dataset as being mappable
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief Identifies the subtree above this node as not being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Identifies and block repeat under cache scenarios
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
private:

View File

@ -48,14 +48,14 @@ class CacheTransformPass : public TreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that
/// will be involved in a cache transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
#ifndef ENABLE_ANDROID
@ -63,95 +63,95 @@ class CacheTransformPass : public TreePass {
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
#ifdef ENABLE_PYTHON
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
#endif
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
#ifndef ENABLE_ANDROID
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
#endif
@ -161,12 +161,12 @@ class CacheTransformPass : public TreePass {
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
/// \return Status The status code returned
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
/// \return Status The status code returned
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
@ -191,7 +191,7 @@ class CacheTransformPass : public TreePass {
/// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
private:
@ -212,7 +212,7 @@ class CacheTransformPass : public TreePass {
/// \param[in] leaf_op The leaf node in the transform
/// \param[in] cache_op The cache op in the transform (will get removed)
/// \param[in] cache_client The cache client
/// \return Status The error code return
/// \return Status The status code returned
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
};

View File

@ -38,61 +38,61 @@ class CacheValidationPass : public IRNodePass {
/// \brief Returns an error if BatchNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;
/// \brief Returns an error if ConcatNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override;
/// \brief Returns an error if FilterNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override;
/// \brief Returns an error if SkipNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;
/// \brief Returns an error if TakeNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;
/// \brief Returns an error if ZipNode exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override;
/// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
/// \brief Returns an error if there is a cache over another cache
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
/// \brief Identifies and block repeat under cache scenarios
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override;
/// \brief Identifies the subtree above this node as not being cached
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
private:

View File

@ -45,27 +45,27 @@ class EpochCtrlPass : public IRTreePass {
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<RootNode> node, bool *modified) override;
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override;
#ifndef ENABLE_ANDROID
/// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override;
#endif
/// \brief Register the TransferNode for further action.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override;
/// \brief Getter
@ -89,7 +89,7 @@ class EpochCtrlPass : public IRTreePass {
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
};
} // namespace dataset

View File

@ -46,20 +46,20 @@ class EpochInjectionPass : public TreePass {
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;
/// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override;
#endif
/// \brief Register the DeviceQueueOp for further action.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
/// \brief Getter
@ -79,7 +79,7 @@ class EpochInjectionPass : public TreePass {
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
};
} // namespace dataset

View File

@ -17,7 +17,10 @@
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
namespace mindspore {
namespace dataset {
@ -47,7 +50,16 @@ Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> no
return Status::OK();
}
// Perform ShuffleOp removal check.
// Perform RepeatNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
*modified = false;
if (node->Count() == 1) {
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
return Status::OK();
}
// Perform ShuffleNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
*modified = false;
#if 0
@ -60,6 +72,24 @@ Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, b
return Status::OK();
}
// Perform SkipNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
*modified = false;
if (node->Count() == 0) {
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
return Status::OK();
}
// Perform TakeNode removal check.
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
*modified = false;
if (node->Count() == -1) {
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
}
return Status::OK();
}
// constructor
NodeRemovalPass::NodeRemovalPass() {}

View File

@ -45,21 +45,39 @@ class NodeRemovalPass : public IRTreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
/// \brief Perform RepeatNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<RepeatNode> node, bool *modified) override;
/// \brief Perform ShuffleNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override;
/// \brief Perform SkipNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;
/// \brief Perform TakeNode removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The status code returned
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;
/// \brief Getter
/// \return All the nodes to be removed
std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; }
@ -79,7 +97,7 @@ class NodeRemovalPass : public IRTreePass {
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
};
} // namespace dataset

View File

@ -46,20 +46,20 @@ class RemovalPass : public TreePass {
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
#endif
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
/// \brief Getter
@ -81,7 +81,7 @@ class RemovalPass : public TreePass {
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
/// \param[inout] tree The tree to operate on.
/// \param[inout] Indicate of the tree was modified.
/// \return Status The error code return
/// \return Status The status code returned
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
};
} // namespace dataset

View File

@ -53,7 +53,7 @@ class ConnectorSize : public Sampling {
std::string Name() const override { return kConnectorSizeSamplingName; }
// Save sampling data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override;
Status Init(const std::string &dir_path, const std::string &device_id) override;

View File

@ -65,7 +65,7 @@ class ConnectorThroughput : public Sampling {
std::string Name() const override { return name_; };
// Save sampling data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override;
Status Init(const std::string &dir_path, const std::string &device_id);

View File

@ -32,13 +32,13 @@ class DatasetIteratorTracing : public Tracing {
~DatasetIteratorTracing() override = default;
// Record tracing data
// @return Status - The error code return
// @return Status The status code returned
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);
std::string Name() const override { return kDatasetIteratorTracingName; };
// Save tracing data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override;
Status Init(const std::string &dir_path, const std::string &device_id) override;

View File

@ -32,13 +32,13 @@ class DeviceQueueTracing : public Tracing {
~DeviceQueueTracing() override = default;
// Record tracing data
// @return Status - The error code return
// @return Status The status code returned
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);
std::string Name() const override { return kDeviceQueueTracingName; };
// Save tracing data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveToFile() override;
Status Init(const std::string &dir_path, const std::string &device_id) override;

View File

@ -87,19 +87,19 @@ class ProfilingManager {
Status Initialize();
// Save profile data to file
// @return Status - The error code return
// @return Status The status code returned
Status SaveProfilingData();
// Sampling node getter
// @param name - The name of the requested node
// @param node - Pointer to the shared pointer for the Sampling node
// @return Status - The error code return
// @return Status The status code returned
Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node);
// Tracing node getter
// @param name - The name of the requested node
// @param node - Pointer to the shared pointer for the Tracing node
// @return Status - The error code return
// @return Status The status code returned
Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node);
// If profiling is enabled.
@ -120,12 +120,12 @@ class ProfilingManager {
// Register profile node to tree
// @param node - Profiling node
// @return Status - The error code return
// @return Status The status code returned
Status RegisterTracingNode(std::shared_ptr<Tracing> node);
// Register profile node to tree
// @param node - Profiling node
// @return Status - The error code return
// @return Status The status code returned
Status RegisterSamplingNode(std::shared_ptr<Sampling> node);
ExecutionTree *tree_ = nullptr; // ExecutionTree pointer

View File

@ -442,7 +442,7 @@ class SchemaObj {
Status parse_column(nlohmann::json columns);
/// \brief Get schema file from JSON file
/// \param[in] json_obj Object of JSON parsed.
/// \param[in] json_obj parsed JSON object
/// \return Status code
Status from_json(nlohmann::json json_obj);

View File

@ -77,7 +77,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
// @param std::shared_ptr<Tensor> *dst - return tensor padded
// @param std::vector<dsize_t> pad_shape - shape to pad to
// @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format,
// @return - The error code return
// @return Status The status code returned
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
const std::shared_ptr<Tensor> &pad_val);
@ -86,7 +86,7 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
// @param std::shared_ptr<Tensor> *dst - return tensor padded
// @param std::vector<dsize_t> pad_shape - shape to pad to
// @param float pad_val - value to pad with
// @return - The error code return
// @return Status The status code returned
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
const std::vector<dsize_t> &pad_shape, float pad_val);
@ -98,7 +98,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
// @param std::vector<dsize_t> cur_ind - recursion helper
// @param T pad_val - value to pad tensor with
// @param size_t cur_dim - recursion helper
// @return Status - The error code return
// @return Status The status code returned
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
std::vector<dsize_t> cur_ind, size_t cur_dim = 0);
@ -107,7 +107,7 @@ Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<T
// @param std::shared_ptr<Tensor> *dst - return tensor padded
// @param std::vector<dsize_t> pad_shape - shape to pad to
// @param std::string pad_val - value to pad with
// @return - The error code return
// @return Status The status code returned
Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
const std::vector<dsize_t> &pad_shape, const std::string &pad_val);
@ -119,7 +119,7 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
// @param std::vector<dsize_t> cur_ind - recursion helperas text
// @param std::string pad_val - value to pad tensor with
// @param size_t cur_dim - recursion helper
// @return Status - The error code return
// @return Status The status code returned
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
const std::string &pad_value);

View File

@ -36,7 +36,7 @@ class ToFloat16Op : public TensorOp {
// Overrides the base class compute function
// Calls the ToFloat16 function in ImageUtils, this function takes an input tensor
// and transforms its data to float16, the output memory is manipulated to contain the result
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

View File

@ -58,7 +58,7 @@ class CutOutOp : public TensorOp {
// Overrides the base class compute function
// Calls the erase function in ImageUtils, this function takes an input tensor
// and overwrites some of its data using openCV, the output memory is manipulated to contain the result
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kCutOutOp; }

View File

@ -50,7 +50,7 @@ class RandomColorAdjustOp : public TensorOp {
// Overrides the base class compute function.
// Calls multiple transform functions in ImageUtils, this function takes an input tensor.
// and transforms its data using openCV, the output memory is manipulated to contain the result.
// @return Status - The error code return.
// @return Status The status code returned.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRandomColorAdjustOp; }

View File

@ -61,7 +61,7 @@ class RandomRotationOp : public TensorOp {
// Overrides the base class compute function
// Calls the rotate function in ImageUtils, this function takes an input tensor
// and transforms its data using openCV, the output memory is manipulated to contain the result
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

View File

@ -43,7 +43,7 @@ class UniformAugOp : public TensorOp {
void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; }
// Overrides the base class compute function
// @return Status - The error code return
// @return Status The status code returned
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kUniformAugOp; }

View File

@ -53,7 +53,7 @@ class DataHelper {
/// \param key Key of field to write to
/// \param value Value array to write to file
/// \param out_file Optional input for output file path, will write to input file if not specified
/// \return Status The error code return
/// \return Status The status code returned
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value,
const std::string &out_file = "");
@ -62,7 +62,7 @@ class DataHelper {
/// \param key Key of field to write to
/// \param value Value array to write to file
/// \param out_file Optional parameter for output file path, will write to input file if not specified
/// \return Status The error code return
/// \return Status The status code returned
template <typename T>
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value,
const std::string &out_file = "") {
@ -99,7 +99,7 @@ class DataHelper {
/// \param key Key of field to write to
/// \param value Value to write to file
/// \param out_file Optional parameter for output file path, will write to input file if not specified
/// \return Status The error code return
/// \return Status The status code returned
template <typename T>
Status UpdateValue(const std::string &in_file, const std::string &key, const T &value,
const std::string &out_file = "") {
@ -134,7 +134,7 @@ class DataHelper {
/// \brief Template function to write tensor to file
/// \param[in] in_file File to write to
/// \param[in] data Array of type T values
/// \return Status The error code return
/// \return Status The status code returned
template <typename T>
Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) {
try {
@ -157,7 +157,7 @@ class DataHelper {
/// \param[in] in_file File name to write to
/// \param[in] data Pointer to data
/// \param[in] length Length of values to write from pointer
/// \return Status The error code return
/// \return Status The status code returned
template <typename T>
Status WriteBinFile(const std::string &in_file, T *data, size_t length) {
try {
@ -188,7 +188,7 @@ class DataHelper {
/// note This function will return okay even if key not found
/// \param[in] in_file Json file to remove key from
/// \param[in] key The key to remove
/// \return Status The error code return
/// \return Status The status code returned
Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = "");
/// \brief A print method typically used for debugging

View File

@ -669,8 +669,6 @@ class Dataset:
>>> repeat_and_shuffle = data.repeat(50)
>>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
"""
if count == 1:
return self
return RepeatDataset(self, count)
@check_skip
@ -717,8 +715,6 @@ class Dataset:
>>> # Create a dataset where the dataset includes 50 elements.
>>> data = data.take(50)
"""
if count == -1:
return self
return TakeDataset(self, count)
def _get_absolute_split_sizes(self, sizes):

View File

@ -1311,6 +1311,51 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSkipTakeRepeat) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipTakeRepeat.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 6));
// Create a Skip operation on ds
int32_t count = 0;
ds = ds->Skip(count);
// Create a Project operation on ds
std::vector<std::string> column_project = {"image"};
ds = ds->Project(column_project);
// Add a Take(-1)
ds = ds->Take(-1);
// Add a Repeat(1)
ds = ds->Repeat(1);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// iterate over the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
MS_LOG(INFO) << "Number of rows: " << i;
// Expect 6 rows
EXPECT_EQ(i, 6);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize.";

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
@ -163,7 +163,7 @@ def test_take_08():
def test_take_09():
"""
Test take: repeat count is -1, and read the whole dataset, take after repeat
Test take: take count is -1, and read the whole dataset, take after repeat
"""
logger.info("test_take_09")
data1 = ds.GeneratorDataset(generator, ["data"])
@ -180,7 +180,7 @@ def test_take_09():
def test_take_10():
"""
Test take: repeat count is -1, and read the whole dataset, take before repeat
Test take: take count is -1, and read the whole dataset, take before repeat
"""
logger.info("test_take_10")
data1 = ds.GeneratorDataset(generator, ["data"])
@ -341,6 +341,18 @@ def test_take_18():
assert sum([1 for _ in data1]) == 2
def test_take_19():
"""
Test take: take is after batch, that mean take(N), N refer to batches num
"""
logger.info("test_take_19")
with pytest.raises(ValueError) as info:
data1 = ds.GeneratorDataset(generator, ["data"])
data1 = data1.batch(2)
data1 = data1.take(0)
assert "positive integer" in str(info.value)
if __name__ == '__main__':
test_take_01()
test_take_02()
@ -360,4 +372,5 @@ if __name__ == '__main__':
test_take_16()
test_take_17()
test_take_18()
test_take_19()
logger.info('== test take operation finished ==')