forked from mindspore-Ecosystem/mindspore
Remove Repeat(1),Take(-1), and Skip(0) in NodeRemovalPass
This commit is contained in:
parent
3280474d71
commit
4cb78f2e03
|
@ -490,12 +490,6 @@ RenameDataset::RenameDataset(std::shared_ptr<Dataset> input, const std::vector<s
|
|||
#endif
|
||||
|
||||
RepeatDataset::RepeatDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
||||
// Workaround for repeat == 1, do not inject repeat.
|
||||
if (count == 1) {
|
||||
ir_node_ = input->IRNode();
|
||||
return;
|
||||
}
|
||||
|
||||
auto ds = std::make_shared<RepeatNode>(input->IRNode(), count);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
|
@ -516,13 +510,6 @@ SkipDataset::SkipDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
|||
}
|
||||
|
||||
TakeDataset::TakeDataset(std::shared_ptr<Dataset> input, int32_t count) {
|
||||
// If count is greater than the number of elements in the dataset or equal to -1,
|
||||
// all the elements in the dataset will be taken
|
||||
if (count == -1) {
|
||||
ir_node_ = input->IRNode();
|
||||
return;
|
||||
}
|
||||
|
||||
auto ds = std::make_shared<TakeNode>(input->IRNode(), count);
|
||||
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
|
|
|
@ -66,7 +66,7 @@ class ColDescriptor {
|
|||
/// an unknown dimension, then the output shape returned shall resolve dimensions as needed.
|
||||
/// \param[in] num_elements - The number of elements in the data for a Tensor
|
||||
/// \param[inout] out_shape - The materialized output Tensor shape
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const;
|
||||
|
||||
/// \brief << Stream output operator overload
|
||||
|
@ -124,13 +124,13 @@ class DataSchema {
|
|||
/// \brief Parses a schema json file and populates the columns and meta info.
|
||||
/// \param[in] schema_file_path - the schema file that has the column's info to load
|
||||
/// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status LoadSchemaFile(const std::string &schema_file_path, const std::vector<std::string> &columns_to_load);
|
||||
|
||||
/// \brief Parses a schema JSON string and populates the columns and meta info.
|
||||
/// \param[in] schema_json_string - the schema file that has the column's info to load
|
||||
/// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status LoadSchemaString(const std::string &schema_json_string, const std::vector<std::string> &columns_to_load);
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
|
@ -148,7 +148,7 @@ class DataSchema {
|
|||
|
||||
/// \brief Adds a column descriptor to the schema
|
||||
/// \param[in] cd - The ColDescriptor to add
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status AddColumn(const ColDescriptor &cd);
|
||||
|
||||
/// \brief getter
|
||||
|
@ -169,7 +169,7 @@ class DataSchema {
|
|||
|
||||
/// \brief Loops through all columns in the schema and returns a map with the column name to column index number.
|
||||
/// \param[inout] out_column_name_map - The output map of columns names to column index
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status GetColumnNameMap(std::unordered_map<std::string, int32_t> *out_column_name_map);
|
||||
|
||||
private:
|
||||
|
@ -177,7 +177,7 @@ class DataSchema {
|
|||
/// does not follow any particular order (json standard does not enforce any ordering protocol).
|
||||
/// This one produces a schema that contains all of the columns from the schema file.
|
||||
/// \param[in] column_tree - The nlohmann tree from the json file to parse
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status AnyOrderLoad(nlohmann::json column_tree);
|
||||
|
||||
/// \brief Internal helper function. For each input column name, perform a lookup to the json document to
|
||||
|
@ -185,18 +185,18 @@ class DataSchema {
|
|||
/// descriptor and add to the schema in the order in which the input column names are given.
|
||||
/// \param[in] column_tree - The nlohmann tree from the json file to parse
|
||||
/// \param[in] columns_to_load - list of strings for the columns to add to the schema
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load);
|
||||
|
||||
/// \brief Internal helper function. Given the json tree for a given column, load it into our schema.
|
||||
/// \param[in] columnTree - The nlohmann child tree for a given column to load.
|
||||
/// \param[in] col_name - The string name of the column for that subtree.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name);
|
||||
|
||||
/// \brief Internal helper function. Performs sanity checks on the json file setup.
|
||||
/// \param[in] js - The nlohmann tree for the schema file
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreLoadExceptionCheck(const nlohmann::json &js);
|
||||
|
||||
std::vector<ColDescriptor> col_descs_; // Vector of column descriptors
|
||||
|
|
|
@ -53,7 +53,7 @@ class IteratorBase {
|
|||
// functionality exists in the derived versions of this function.
|
||||
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data
|
||||
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
// @note The position of a Tensor/column might be different from the initial column order
|
||||
// in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change
|
||||
// the column ordering.
|
||||
|
@ -97,17 +97,17 @@ class DatasetIterator : public IteratorBase {
|
|||
// from the tree root node directly.
|
||||
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data
|
||||
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status FetchNextTensorRow(TensorRow *out_row) override;
|
||||
|
||||
// Fetches the next tensor row into device row, and returns its shape.
|
||||
// @param out_shapes - A vector of tensor shapes (one shape per column)
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetOutputShapes(std::vector<TensorShape> *out_shapes);
|
||||
|
||||
// Fetches the next tensor row into device row, and returns its column data types.
|
||||
// @param out_types - A vector of data types (one type per column)
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetOutputTypes(std::vector<DataType> *out_types);
|
||||
|
||||
// Getter
|
||||
|
@ -140,12 +140,12 @@ class ChildIterator : public IteratorBase {
|
|||
// only from the child/worker id as given from the constructor.
|
||||
// @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the data
|
||||
// messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status FetchNextTensorRow(TensorRow *out_row) override;
|
||||
|
||||
// This function drains buffer until next eoe has been received.
|
||||
// It will be a no-op if the previous row returned is empty.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Drain();
|
||||
|
||||
// Getter
|
||||
|
|
|
@ -134,7 +134,7 @@ class BarrierOp : public PipelineOp {
|
|||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Handles preprocessing of the main loop, used when starting new epoch
|
||||
|
|
|
@ -112,12 +112,12 @@ class BatchOp : public ParallelOp {
|
|||
#endif
|
||||
|
||||
// @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<BatchOp> *);
|
||||
|
||||
private:
|
||||
// Sanity check for builder class args
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
bool builder_drop_;
|
||||
|
@ -167,11 +167,11 @@ class BatchOp : public ParallelOp {
|
|||
~BatchOp() {}
|
||||
|
||||
// @param int32_t workerId
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status EofReceived(int32_t) override;
|
||||
|
||||
// @param int32_t workerId
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status EoeReceived(int32_t) override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -190,7 +190,7 @@ class BatchOp : public ParallelOp {
|
|||
}
|
||||
|
||||
// Main loop of batch
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
@ -214,14 +214,14 @@ class BatchOp : public ParallelOp {
|
|||
// @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows
|
||||
// @param int32_t size - batch_size
|
||||
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest,
|
||||
dsize_t batch_size);
|
||||
|
||||
// @param table
|
||||
// @param const PadInfo &pad_info pad info
|
||||
// @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map);
|
||||
|
||||
|
@ -233,18 +233,18 @@ class BatchOp : public ParallelOp {
|
|||
private:
|
||||
// Worker thread for doing the memcpy of batch
|
||||
// @param int32_t param workerId
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Generate buffer with batched tensors
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair,
|
||||
std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// Function that calls pyfunc to perform map on batch
|
||||
// @param (std::pair<std::unique_ptr<TensorQTable>, batch_stats> *table_pair - contains un-batched tensor
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair);
|
||||
#endif
|
||||
|
||||
|
@ -253,7 +253,7 @@ class BatchOp : public ParallelOp {
|
|||
// @param std::set<int32_t> *cols, col ids to perform pad on
|
||||
// @param std::vector<float> *vals, default padding value for each column
|
||||
// @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
static Status UnpackPadInfo(const PadInfo &pad_info,
|
||||
const std::unordered_map<std::string, int32_t> &column_name_id_map,
|
||||
std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals,
|
||||
|
@ -264,20 +264,20 @@ class BatchOp : public ParallelOp {
|
|||
int32_t num_consumers() const override { return 1; }
|
||||
|
||||
// get the batch size for next batch
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetBatchSize(int32_t *batch_size, CBatchInfo info);
|
||||
|
||||
// Do the initialization of all queues then start all worker threads
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
// Invoke batch size function with current BatchInfo to generate batch size.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info);
|
||||
|
||||
// Invoke batch map function with current BatchInfo to generate tensors to batch.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
|
||||
#endif
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ class BucketBatchByLengthOp : public PipelineOp {
|
|||
|
||||
// Might need to batch remaining buckets after receiving eoe, so override this method.
|
||||
// @param int32_t workerId
|
||||
// @return Status - The error code returned
|
||||
// @return Status The status code returned
|
||||
Status EoeReceived(int32_t) override;
|
||||
|
||||
std::string Name() const override { return kBucketBatchByLengthOp; }
|
||||
|
@ -123,7 +123,7 @@ class BucketBatchByLengthOp : public PipelineOp {
|
|||
}
|
||||
|
||||
// Main loop of batch
|
||||
// @return Status - The error code returned
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -104,7 +104,7 @@ class BuildSentencePieceVocabOp : public PipelineOp {
|
|||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<BuildSentencePieceVocabOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<BuildSentencePieceVocabOp> *op);
|
||||
|
||||
private:
|
||||
|
|
|
@ -110,7 +110,7 @@ class BuildVocabOp : public ParallelOp {
|
|||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<BuildVocabOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<BuildVocabOp> *op);
|
||||
|
||||
private:
|
||||
|
|
|
@ -53,7 +53,7 @@ class CacheBase : public ParallelOp {
|
|||
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||
/// info from its previous execution and then initializes itself so that it can be executed
|
||||
/// again.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
|
|
|
@ -80,7 +80,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
|
|||
std::shared_ptr<SamplerRT> build_sampler_;
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
// \return Status The error code return
|
||||
// \return Status The status code returned
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
/// \brief Constructor
|
||||
|
|
|
@ -136,7 +136,7 @@ class CacheMergeOp : public ParallelOp {
|
|||
std::shared_ptr<SamplerRT> build_sampler_;
|
||||
|
||||
/// Check if the required parameters are set by the builder.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
||||
|
@ -189,7 +189,7 @@ class CacheMergeOp : public ParallelOp {
|
|||
|
||||
/// \brief Base-class override for handling cases when an eof is received.
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status EofReceived(int32_t worker_id) override;
|
||||
|
||||
protected:
|
||||
|
|
|
@ -99,7 +99,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
std::shared_ptr<SamplerRT> build_sampler_;
|
||||
|
||||
/// \brief Check if the required parameters are set by the builder.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
||||
|
@ -119,7 +119,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
/// \brief Base-class override for special eoe handler.
|
||||
/// CacheOp must override this because it shall not perform default handling of eoe. Instead
|
||||
/// the CacheOp manages actions related to the end of the epoch.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
/// \param[in] p The node to visit
|
||||
|
@ -133,7 +133,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
Status Accept(NodePass *p, bool *modified) override;
|
||||
/// \brief Base-class override for handling cases when an eof is received.
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status EofReceived(int32_t worker_id) override;
|
||||
Status operator()() override;
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
@ -159,7 +159,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
Status CacheAllRows(int32_t worker_id);
|
||||
Status RegisterResources() override;
|
||||
/// \brief Private function for cache setup/init work just after construction
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status InitCache();
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -94,7 +94,7 @@ class ConcatOp : public PipelineOp {
|
|||
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Op name getter
|
||||
|
|
|
@ -146,14 +146,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// DatasetOps operate by launching a thread (see ExecutionTree).
|
||||
/// This pure virtual version makes the requirement that derived classes must provide a functor
|
||||
/// that will execute their main runtime loop code.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status operator()() = 0;
|
||||
|
||||
/// \brief Gets the next buffer from the given child
|
||||
/// \note See GetNextInput for similar function that has built-in message handling
|
||||
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
|
||||
return GetNextBuffer(p_buffer, worker_id, false);
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \brief Gets the next buffer from the given child
|
||||
/// \note See GetNextInput for similar function that has built-in message handling
|
||||
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); }
|
||||
|
||||
/// \brief Gets the next buffer from the given child
|
||||
|
@ -169,7 +169,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
|
||||
/// \param worker_id - The worker id
|
||||
/// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe);
|
||||
|
||||
/// \brief Gets the next buffer from the given child. This function also has built-in eoe and eof
|
||||
|
@ -177,7 +177,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// those messages are received.
|
||||
/// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
|
||||
|
||||
/// \brief Gets the batch size
|
||||
|
@ -200,19 +200,19 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// The base class implementation simply flows the eoe message to output. Derived classes
|
||||
/// may override if they need to perform special eoe handling.
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status EoeReceived(int32_t worker_id);
|
||||
|
||||
/// \brief Performs handling for when an eof message is received.
|
||||
/// The base class implementation simply flows the eof message to output. Derived classes
|
||||
/// may override if they need to perform special eof handling.
|
||||
/// \param worker_id - The worker id
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status EofReceived(int32_t worker_id);
|
||||
|
||||
/// \brief Derived classes may implement the reset function if the operator is stateful and needs
|
||||
/// specific reset handling that is not contained in this common code version of the reset
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status Reset();
|
||||
|
||||
/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
|
||||
|
|
|
@ -79,7 +79,7 @@ class FilterOp : public ParallelOp {
|
|||
|
||||
private:
|
||||
// Sanity check for builder class args.
|
||||
// @return Status - The error code return.
|
||||
// @return Status The status code returned.
|
||||
Status SanityCheck();
|
||||
std::vector<std::string> build_in_col_names_;
|
||||
std::shared_ptr<TensorOp> builder_predicate_func_;
|
||||
|
@ -105,15 +105,15 @@ class FilterOp : public ParallelOp {
|
|||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work.
|
||||
// @return Status The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// @param int32_t workerId.
|
||||
// @return Status - The error code return.
|
||||
// @return Status The status code returned.
|
||||
Status EofReceived(int32_t) override;
|
||||
|
||||
// @param int32_t workerId.
|
||||
// @return Status - The error code return.
|
||||
// @return Status The status code returned.
|
||||
Status EoeReceived(int32_t) override;
|
||||
|
||||
// A print method typically used for debugging.
|
||||
|
@ -151,34 +151,34 @@ class FilterOp : public ParallelOp {
|
|||
// logic of FilterOp, getting the data from previous Op, validating user specified column names,
|
||||
// applying predicate to each of the data, filter the data when predicate result is false.
|
||||
// @param worker_id The id assigned to this thread/worker upon creation.
|
||||
// @return Status The error code return.
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_
|
||||
|
||||
// Filter the data by predicate function.
|
||||
// @param in_buffer input data buffer.
|
||||
// @param to_proess_indices Indices of columns to be processed.
|
||||
// @param out data buffer that are filtered by predicate.
|
||||
// @return Status The error code return.
|
||||
// @return Status The status code returned
|
||||
Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);
|
||||
|
||||
// Collector databuffer.
|
||||
// @return Status The error code return.
|
||||
// @return Status The status code returned
|
||||
Status Collector();
|
||||
|
||||
// @param input tensor vector.
|
||||
// @return Status - The error code return.
|
||||
// @return Status The status code returned.
|
||||
Status CheckInput(const TensorRow &input) const;
|
||||
|
||||
// Invoke python func.
|
||||
// @param input tensor vector.
|
||||
// @param the result of predicate.
|
||||
// @return Status - The error code return.
|
||||
// @return Status The status code returned.
|
||||
Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);
|
||||
|
||||
// Private function for validating if each of the user specified input column names
|
||||
// exist in the DataBuffer.
|
||||
// @param input_columns The vector of input column names used in the current thread.
|
||||
// @return Status The error code return.
|
||||
// @return Status The status code returned
|
||||
Status ValidateInColumns(const std::vector<std::string> *input_columns);
|
||||
|
||||
// Private function for checking the column legality
|
||||
|
|
|
@ -133,7 +133,7 @@ class MapOp : public ParallelOp {
|
|||
int32_t build_op_connector_size_;
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
// @return Status The error code return
|
||||
// @return Status The status code returned
|
||||
Status sanityCheck() const;
|
||||
};
|
||||
|
||||
|
@ -170,7 +170,7 @@ class MapOp : public ParallelOp {
|
|||
// provide the master loop that drives the logic for performing the work
|
||||
// This main thread creates local queues, pulls databuffers from the previous
|
||||
// op's Connector and distributes them to the local queues. Workers pull from the local queues.
|
||||
// @return Status The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Getter
|
||||
|
@ -239,7 +239,7 @@ class MapOp : public ParallelOp {
|
|||
// applying a list of TensorOps to each of the data, process the results and then
|
||||
// pushing them back to MapOp's output Connector to be fetched by the next Op.
|
||||
// @param worker_id The id assigned to this thread/worker upon creation.
|
||||
// @return Status The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_
|
||||
|
||||
// Private function for worker thread to perform TensorOp's compute function and get the result.
|
||||
|
|
|
@ -89,7 +89,7 @@ class ParallelOp : public DatasetOp {
|
|||
}
|
||||
|
||||
// Override base class reset to provide reset actions specific to the ParallelOp class.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Getter
|
||||
|
@ -115,7 +115,7 @@ class ParallelOp : public DatasetOp {
|
|||
protected:
|
||||
// Interface for derived classes to implement. All derived classes must provide the entry
|
||||
// function with the main execution loop for worker threads.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status WorkerEntry(int32_t workerId) = 0;
|
||||
|
||||
/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp
|
||||
|
|
|
@ -75,7 +75,7 @@ class ProjectOp : public PipelineOp {
|
|||
// However, the ProjectOp is defined as an inlined operator, so it is invalid to launch the
|
||||
// functor since this op runs inlined inside another operator. The function is overloaded to
|
||||
// ensure that it is not called by mistake (it will generate an error).
|
||||
// @return Status - The error code returned.
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Gets a buffer from the child node and projects that buffer. The caller is typically our parent node.
|
||||
|
@ -93,12 +93,12 @@ class ProjectOp : public PipelineOp {
|
|||
|
||||
// Base-class override for special eoe handler.
|
||||
// Inline operators must override this because there is no connector to push eoe onto.
|
||||
// @return Status - The error code returned.
|
||||
// @return Status The status code returned
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
// Base-class override for special eof handler.
|
||||
// Inline operators must override this because there is no connector to push eof onto.
|
||||
// @return Status - The error code returned.
|
||||
// @return Status The status code returned
|
||||
Status EofReceived(int32_t worker_id) override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
|
|
@ -107,7 +107,7 @@ class RenameOp : public PipelineOp {
|
|||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
|
|
@ -78,7 +78,7 @@ class RepeatOp : public PipelineOp {
|
|||
// However, the RepeatOp is defined as an inlined operator, so it is invalid to launch the
|
||||
// functor since this op runs inlined inside another operator. The function is overloaded to
|
||||
// ensure that it is not called by mistake (it will generate an error).
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// This function returns the buffer that is at the top of our output connector. The caller is
|
||||
|
@ -90,7 +90,7 @@ class RepeatOp : public PipelineOp {
|
|||
// @param p_buffer - output pointer to the buffer that it will fetch.
|
||||
// @param worker_id - The worker id
|
||||
// @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
|
@ -130,7 +130,7 @@ class RepeatOp : public PipelineOp {
|
|||
int32_t num_repeats() { return num_repeats_; }
|
||||
|
||||
/// \brief reset Op
|
||||
/// \@return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
int64_t GetTreeRepeatCount() override;
|
||||
|
|
|
@ -146,13 +146,13 @@ class ShuffleOp : public PipelineOp {
|
|||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for special eoe handler.
|
||||
// ShuffleOp must override this because it shall not perform default handling of eoe. Instead
|
||||
// the ShuffleOp needs to manage actions related to the end of the epoch itself.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status EoeReceived(int32_t worker_id) override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
@ -167,17 +167,17 @@ class ShuffleOp : public PipelineOp {
|
|||
|
||||
private:
|
||||
// Private function to add a new row to the shuffle buffer.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status AddRowToShuffleBuffer(TensorRow new_shuffle_row);
|
||||
|
||||
// Private function to populate the shuffle buffer initially by fetching from the child output
|
||||
// connector until the shuffle buffer is full (or there is no more data coming).
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitShuffleBuffer();
|
||||
|
||||
// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
|
||||
// itself rather than waiting for the reset driven from operators above it in the pipeline.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SelfReset();
|
||||
|
||||
int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows)
|
||||
|
|
|
@ -63,7 +63,7 @@ class SkipOp : public PipelineOp {
|
|||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for handling cases when an eoe is received.
|
||||
|
|
|
@ -130,12 +130,12 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
/// \brief Check validity of input args
|
||||
/// \return - The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
/// \brief The builder "build" method creates the final object.
|
||||
/// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp
|
||||
/// \return - The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status Build(std::shared_ptr<AlbumOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -168,18 +168,18 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
~AlbumOp() = default;
|
||||
|
||||
/// \brief Initialize AlbumOp related var, calls the function to walk all files
|
||||
/// \return - The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status PrescanEntry();
|
||||
|
||||
/// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
/// \param[in] int32_t workerId - id of each worker
|
||||
/// \return Status - The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
/// \brief Main Loop of AlbumOp
|
||||
/// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
/// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
/// \return Status - The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
|
@ -204,93 +204,93 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
/// \brief Initialize Sampler, calls sampler->Init() within
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
/// \brief Load image to tensor row
|
||||
/// \param[in] image_file Image name of file
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load vector of ints to tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing multi-dimensional label
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load vector of floats to tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing array data
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load string array into a tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing string tensor
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load string into a tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing string tensor
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load float value to tensor row
|
||||
/// \param[in] json_obj Json object containing float
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load int value to tensor row
|
||||
/// \param[in] json_obj Json object containing int
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load empty tensor to tensor row
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadEmptyTensor(uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load id from file name to tensor row
|
||||
/// \param[in] file The file name to get ID from
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load a tensor row according to a json file
|
||||
/// \param[in] row_id_type row_id - id for this tensor row
|
||||
/// \param[in] ImageColumns file Json file location
|
||||
/// \param[inout] TensorRow row Json content stored into a tensor row
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, const std::string &file, TensorRow *row);
|
||||
|
||||
/// \param[in] const std::vector<int64_t> &keys Keys in ioblock
|
||||
/// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
/// \brief Called first when function is called
|
||||
/// \return Status The error code returned
|
||||
/// \return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
/// \brief reset Op
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return Status The error code returned
|
||||
// @return Status The status code returned
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int32_t rows_per_buffer_;
|
||||
|
|
|
@ -116,12 +116,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
// Check validity of input args
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<CelebAOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<CelebAOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -151,12 +151,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
// Main Loop of CelebAOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t worker_id - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -166,7 +166,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
|
||||
// Method in operator(), to fill IOBlockQueue
|
||||
// @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
|
||||
|
||||
/// \brief Base-class override for NodePass visitor acceptor
|
||||
|
@ -199,14 +199,14 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
|
||||
// @param const std::vector<int64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// Load a tensor row according to a pair
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param std::pair - <image_file,<label>>
|
||||
// @param TensorRow row - image & label read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label,
|
||||
TensorRow *row);
|
||||
|
||||
|
@ -215,7 +215,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
|
|||
bool CheckDatasetTypeValid();
|
||||
|
||||
// reset Op
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
|
|
|
@ -109,12 +109,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Check validity of input args
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<CifarOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<CifarOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -144,13 +144,13 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param uint32_t workerId - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Main Loop of CifarOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -177,18 +177,18 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
// Load a tensor row according to a pair
|
||||
// @param uint64_t index - index need to load
|
||||
// @param TensorRow row - image & label read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(uint64_t index, TensorRow *row);
|
||||
|
||||
// @param const std::vector<uint64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// Read block data from cifar file
|
||||
|
@ -200,7 +200,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
// reset Op
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Get cifar files in dir
|
||||
|
@ -221,7 +221,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
|
|
|
@ -133,12 +133,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Check validity of input args
|
||||
// @return = The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "Build" method creates the final object.
|
||||
// @param std::shared_ptr<CocoOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<CocoOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -173,13 +173,13 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t workerId - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Main Loop of CocoOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -214,19 +214,19 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
std::string Name() const override { return "CocoOp"; }
|
||||
|
||||
/// \brief Gets the class indexing
|
||||
/// \return Status - The status code return
|
||||
/// \return Status The status code returned
|
||||
Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) override;
|
||||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
// Load a tensor row according to image id
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param std::string image_id - image id
|
||||
// @param TensorRow row - image & target read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
|
||||
|
||||
// Load a tensor row with vector which a vector to a tensor
|
||||
|
@ -235,7 +235,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param std::shared_ptr<Tensor> image - image tensor
|
||||
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
|
||||
// @param TensorRow row - image & target read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
|
||||
std::shared_ptr<Tensor> coordinate, TensorRow *trow);
|
||||
|
||||
|
@ -245,7 +245,7 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param std::shared_ptr<Tensor> image - image tensor
|
||||
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
|
||||
// @param TensorRow row - image & target read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
|
||||
std::shared_ptr<Tensor> coordinate, TensorRow *trow);
|
||||
|
||||
|
@ -255,69 +255,69 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param std::shared_ptr<Tensor> image - image tensor
|
||||
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
|
||||
// @param TensorRow row - image & target read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
|
||||
std::shared_ptr<Tensor> coordinate, TensorRow *trow);
|
||||
|
||||
// @param const std::string &path - path to the image file
|
||||
// @param const ColDescriptor &col - contains tensor implementation and datatype
|
||||
// @param std::shared_ptr<Tensor> tensor - return
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
// @param const std::vector<uint64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// Read annotation from Annotation folder
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ParseAnnotationIds();
|
||||
|
||||
// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
|
||||
// @param std::vector<int64_t> *keys - image id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
|
||||
|
||||
// Called first when function is called
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
// Reset dataset state
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// @param nlohmann::json image_tree - image tree of json
|
||||
// @param std::vector<std::string> *image_vec - image id list of json
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ImageColumnLoad(const nlohmann::json &image_tree, std::vector<std::string> *image_vec);
|
||||
|
||||
// @param nlohmann::json categories_tree - categories tree of json
|
||||
// return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status CategoriesColumnLoad(const nlohmann::json &categories_tree);
|
||||
|
||||
// @param nlohmann::json annotation_tree - annotation tree of json
|
||||
// @param const std::string &image_file - current image name in annotation
|
||||
// @param const int32_t &id - current unique id of annotation
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status DetectionColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
|
||||
|
||||
// @param nlohmann::json annotation_tree - annotation tree of json
|
||||
// @param const std::string &image_file - current image name in annotation
|
||||
// @param const int32_t &id - current unique id of annotation
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status StuffColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
|
||||
|
||||
// @param nlohmann::json annotation_tree - annotation tree of json
|
||||
// @param const std::string &image_file - current image name in annotation
|
||||
// @param const int32_t &id - current unique id of annotation
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status KeypointColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file, const int32_t &id);
|
||||
|
||||
// @param nlohmann::json annotation_tree - annotation tree of json
|
||||
// @param const std::string &image_file - current image name in annotation
|
||||
// @param const int32_t &image_id - current unique id of annotation
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PanopticColumnLoad(const nlohmann::json &annotation_tree, const std::string &image_file,
|
||||
const int32_t &image_id);
|
||||
|
||||
|
|
|
@ -115,13 +115,13 @@ class GeneratorOp : public PipelineOp {
|
|||
// Class functor operator () override.
|
||||
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||
// info from its previous execution and then initializes itself so that it can be executed
|
||||
// again.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
|
|
@ -135,12 +135,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Check validity of input args
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<ImageFolderOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<ImageFolderOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -172,28 +172,28 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Initialize ImageFolderOp related var, calls the function to walk all files
|
||||
// @param - std::string dir file directory to ImageNetFolder
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PrescanMasterEntry(const std::string &dir);
|
||||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t workerId - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t workerId - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PrescanWorkerEntry(int32_t worker_id);
|
||||
|
||||
// Main Loop of ImageFolderOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -224,19 +224,19 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
// Load a tensor row according to a pair
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param ImageLabelPair pair - <imagefile,label>
|
||||
// @param TensorRow row - image & label read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row);
|
||||
|
||||
// @param const std::vector<int64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// @param std::string & dir - dir to walk all images
|
||||
|
@ -253,7 +253,7 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
|
|||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
// reset Op
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
|
|
|
@ -58,12 +58,12 @@ class IOBlock {
|
|||
// Fetches the first key from the block.
|
||||
// @note Only useful if you know the block only has 1 key.
|
||||
// @return A copy of the first key from the block
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetKey(int64_t *out_key) const;
|
||||
|
||||
// Fetches the list of keys from this block.
|
||||
// @param out_keys - A copy of the vector of keys from the block.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetKeys(std::vector<int64_t> *out_keys) const;
|
||||
|
||||
// Does this block have the eoe flag turned on?
|
||||
|
@ -110,7 +110,7 @@ class FilenameBlock : public IOBlock {
|
|||
// Gets the filename from the block using the provided index container
|
||||
// @param out_filename - The filename to add to the block
|
||||
// @param index - The index to perform lookup against
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const;
|
||||
|
||||
// Get the start offset of file
|
||||
|
|
|
@ -110,12 +110,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Check validity of input args
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "build" method creates the final object.
|
||||
// @param std::shared_ptr<ManifestOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<ManifestOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -145,18 +145,18 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t worker_id - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Main Loop of ManifestOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -201,37 +201,37 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
// Method in operator(), to fill IOBlockQueue
|
||||
// @param std::unique_ptr<DataBuffer> sampler_buffer - to fill IOBlockQueue
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer);
|
||||
|
||||
// Load a tensor row according to a pair
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>>
|
||||
// @param TensorRow row - image & label read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data,
|
||||
TensorRow *row);
|
||||
|
||||
// @param const std::vector<int64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// Parse manifest file to get image path and label and so on.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ParseManifestFile();
|
||||
|
||||
// Called first when function is called
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
// reset Op
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Check if image is valid. Only supports JPEG/PNG/GIF/BMP
|
||||
|
@ -239,7 +239,7 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
|
|||
Status CheckImageType(const std::string &file_name, bool *valid);
|
||||
|
||||
// Count label index,num rows and num samples
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status CountDatasetInfo();
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
|
|
|
@ -164,13 +164,13 @@ class MindRecordOp : public ParallelOp {
|
|||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t workerId - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Class functor operator () override.
|
||||
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Called first when function is called
|
||||
|
@ -180,7 +180,7 @@ class MindRecordOp : public ParallelOp {
|
|||
// Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||
// info from its previous execution and then initializes itself so that it can be executed
|
||||
// again.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Getter method
|
||||
|
|
|
@ -99,12 +99,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
return *this;
|
||||
}
|
||||
// Check validity of input args
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "Build" method creates the final object.
|
||||
// @param std::shared_ptr<MnistOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<MnistOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -133,18 +133,18 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t worker_id - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Main Loop of MnistOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class
|
||||
// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -170,39 +170,39 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
// Load a tensor row according to a pair
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param ImageLabelPair pair - <imagefile,label>
|
||||
// @param TensorRow row - image & label read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row);
|
||||
|
||||
// @param const std::vector<int64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// Iterate through all members in sampleIds and fill them into IOBlock.
|
||||
// @param std::shared_ptr<Tensor> sample_ids -
|
||||
// @param std::vector<int64_t> *keys - keys in ioblock
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
|
||||
|
||||
// Check image file stream.
|
||||
// @param const std::string *file_name - image file name
|
||||
// @param std::ifstream *image_reader - image file stream
|
||||
// @param uint32_t num_images - returns the number of images
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images);
|
||||
|
||||
// Check label stream.
|
||||
// @param const std::string &file_name - label file name
|
||||
// @param std::ifstream *label_reader - label file stream
|
||||
// @param uint32_t num_labels - returns the number of labels
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels);
|
||||
|
||||
// Read 4 bytes of data from a file stream.
|
||||
|
@ -219,23 +219,23 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param std::ifstream *image_reader - image file stream
|
||||
// @param std::ifstream *label_reader - label file stream
|
||||
// @param int64_t read_num - number of image to read
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index);
|
||||
|
||||
// Parse all mnist dataset files
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ParseMnistData();
|
||||
|
||||
// Read all files in the directory
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WalkAllFiles();
|
||||
|
||||
// Called first when function is called
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
// reset Op
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
|
|
|
@ -63,7 +63,7 @@ class RandomDataOp : public ParallelOp {
|
|||
/**
|
||||
* The build method that produces the instantiated RandomDataOp as a shared pointer
|
||||
* @param out_op - The output RandomDataOperator that was constructed
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status Build(std::shared_ptr<RandomDataOp> *out_op);
|
||||
|
||||
|
@ -128,7 +128,7 @@ class RandomDataOp : public ParallelOp {
|
|||
private:
|
||||
/**
|
||||
* Check if the required parameters are set by the builder.
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status SanityCheck() const;
|
||||
|
||||
|
@ -182,7 +182,7 @@ class RandomDataOp : public ParallelOp {
|
|||
* Class functor operator () override.
|
||||
* All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
|
||||
* provide the master loop that drives the logic for performing the work.
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status operator()() override;
|
||||
|
||||
|
@ -190,7 +190,7 @@ class RandomDataOp : public ParallelOp {
|
|||
* Overrides base class reset method. When an operator does a reset, it cleans up any state
|
||||
* info from its previous execution and then initializes itself so that it can be executed
|
||||
* again.
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status Reset() override;
|
||||
|
||||
|
@ -207,7 +207,7 @@ class RandomDataOp : public ParallelOp {
|
|||
/**
|
||||
* The entry point code for when workers are launched
|
||||
* @param worker_id - The worker id
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
|
@ -219,7 +219,7 @@ class RandomDataOp : public ParallelOp {
|
|||
/**
|
||||
* Performs a synchronization between workers at the end of an epoch
|
||||
* @param worker_id - The worker id
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status EpochSync(int32_t worker_id, bool *quitting);
|
||||
|
||||
|
@ -227,7 +227,7 @@ class RandomDataOp : public ParallelOp {
|
|||
* A helper function to stuff the tensor table into a buffer and send it to output connector
|
||||
* @param worker_id - The worker id
|
||||
* @param in_table - The tensor table to pack and send
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table);
|
||||
|
||||
|
@ -235,7 +235,7 @@ class RandomDataOp : public ParallelOp {
|
|||
* A helper function to create random data for the row
|
||||
* @param worker_id - The worker id
|
||||
* @param new_row - The output row to produce
|
||||
* @return Status - The error code return
|
||||
* @return Status The status code returned
|
||||
*/
|
||||
Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
|
|||
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer
|
||||
// @param int32_t workerId
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
|
@ -53,7 +53,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
|
|||
Status InitSampler() override;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Printer for debugging purposes.
|
||||
|
|
|
@ -41,13 +41,13 @@ class PythonSamplerRT : public SamplerRT {
|
|||
Status InitSampler() override;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// Printer for debugging purposes.
|
||||
|
|
|
@ -40,14 +40,14 @@ class RandomSamplerRT : public SamplerRT {
|
|||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// meant to be called by base class or python
|
||||
Status InitSampler() override;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ResetSampler() override;
|
||||
|
||||
void SamplerPrint(std::ostream &out, bool show_all) const override;
|
||||
|
|
|
@ -35,12 +35,12 @@ class RandomAccessOp {
|
|||
public:
|
||||
// Sampler get number of rows in the dataset
|
||||
// @param int64_t num - return number of rows for this dataset
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNumRowsInDataset(int64_t *num_rows) const;
|
||||
|
||||
// Sampler gets label, imageIds from corresponding Dataset Op, this function is unique to PK
|
||||
// @param std::map<int32_t, std::vector<int64_t>> *map
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
|
||||
RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK");
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ class SamplerRT {
|
|||
// @note It is Sampler responsibility to make sure that the id is not out of bound.
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;
|
||||
|
||||
// This function only called by python layer. Not needed by Android.
|
||||
|
@ -81,7 +81,7 @@ class SamplerRT {
|
|||
#endif
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status ResetSampler() = 0;
|
||||
|
||||
// first handshake between leaf source op and Sampler. This func will determine the amount of data
|
||||
|
@ -114,13 +114,13 @@ class SamplerRT {
|
|||
|
||||
// Adds a sampler to become our child.
|
||||
// @param std::shared_ptr<DatasetOp> - The sampler to add as a child.
|
||||
// @return - The error code returned.
|
||||
// @return Status The status code returned
|
||||
Status AddChild(std::shared_ptr<SamplerRT> child);
|
||||
|
||||
// A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
|
||||
// @param std::shared_ptr<Tensor>* sampleIds
|
||||
// @param int64_t numElements - must be a non 0 number
|
||||
// @return - The error code returned.
|
||||
// @return Status The status code returned
|
||||
Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -146,7 +146,7 @@ class SamplerRT {
|
|||
// associated id.
|
||||
// @param int64_t* out_associated_id - Out parameter, contains the associated id.
|
||||
// @param int64_t id - The id used as an index to get the associated child id.
|
||||
// @return - The error code returned.
|
||||
// @return Status The status code returned
|
||||
Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
|
||||
|
||||
protected:
|
||||
|
|
|
@ -40,13 +40,13 @@ class SequentialSamplerRT : public SamplerRT {
|
|||
Status InitSampler() override;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ResetSampler() override;
|
||||
|
||||
// Op calls this to get next Buffer that contains all the sampleIds
|
||||
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to corresponding Dataset Op
|
||||
// @param int32_t workerId - not meant to be used
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
|
||||
|
||||
// Printer for debugging purposes.
|
||||
|
|
|
@ -132,12 +132,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
// Check validity of input args
|
||||
// @return = The error code return
|
||||
// @return Status The status code returned
|
||||
Status SanityCheck();
|
||||
|
||||
// The builder "Build" method creates the final object.
|
||||
// @param std::shared_ptr<VOCOp> *op - DatasetOp
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Build(std::shared_ptr<VOCOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -173,13 +173,13 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
// Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
// @param int32_t workerId - id of each worker
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Main Loop of VOCOp
|
||||
// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// A print method typically used for debugging
|
||||
|
@ -222,55 +222,55 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
// Initialize Sampler, calls sampler->Init() within
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status InitSampler();
|
||||
|
||||
// Load a tensor row according to image id
|
||||
// @param row_id_type row_id - id for this tensor row
|
||||
// @param std::string image_id - image id
|
||||
// @param TensorRow row - image & target read into this tensor row
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
|
||||
|
||||
// @param const std::string &path - path to the image file
|
||||
// @param const ColDescriptor &col - contains tensor implementation and datatype
|
||||
// @param std::shared_ptr<Tensor> tensor - return
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
// @param const std::string &path - path to the image file
|
||||
// @param TensorRow *row - return
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ReadAnnotationToTensor(const std::string &path, TensorRow *row);
|
||||
|
||||
// @param const std::vector<int64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
// Read image list from ImageSets
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ParseImageIds();
|
||||
|
||||
// Read annotation from Annotation folder
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ParseAnnotationIds();
|
||||
|
||||
// @param const std::string &path - path to annotation xml
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ParseAnnotationBbox(const std::string &path);
|
||||
|
||||
// @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
|
||||
// @param std::vector<int64_t> *keys - image id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
|
||||
|
||||
// Called first when function is called
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
// Reset dataset state
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Reset() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
|
|
|
@ -75,7 +75,7 @@ class TakeOp : public PipelineOp {
|
|||
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
|
|
@ -101,7 +101,7 @@ class ZipOp : public PipelineOp {
|
|||
// Class functor operator () override.
|
||||
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status operator()() override;
|
||||
|
||||
/// \brief Base-class override for NodePass pre-visit acceptor
|
||||
|
|
|
@ -226,7 +226,7 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
|
|||
// Compulsory transformation/action post optimization.
|
||||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status ExecutionTree::Prepare(int32_t num_epochs, bool partial) {
|
||||
num_epochs_ = num_epochs;
|
||||
partially_prepare_ = partial;
|
||||
|
|
|
@ -115,16 +115,16 @@ class ExecutionTree {
|
|||
// provides it with a link to the tree. A node cannot form any relationships (parent/child) with
|
||||
// other nodes unless they are associated with the same tree.
|
||||
// @param op - The operator to associate
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status AssociateNode(const std::shared_ptr<DatasetOp> &op);
|
||||
|
||||
// Sets the root node of the tree
|
||||
// @param op - The operator to assign as root
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status AssignRoot(const std::shared_ptr<DatasetOp> &op);
|
||||
|
||||
// Start the execution of the tree
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Launch();
|
||||
|
||||
/// A print method typically used for debugging
|
||||
|
@ -155,7 +155,7 @@ class ExecutionTree {
|
|||
// wrapper for the TaskGroup handling that is stored inside the execution tree.
|
||||
// @param num_workers - The number of workers to launch
|
||||
// @param func - The function entry point that workers will execute
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "");
|
||||
|
||||
// Getter method
|
||||
|
@ -181,32 +181,32 @@ class ExecutionTree {
|
|||
// Compulsory transformation/action post optimization.
|
||||
// For example, repeatOp inlining
|
||||
//
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Prepare(int num_epochs = -1, bool partial = false);
|
||||
|
||||
// Compulsory transformation/action pre optimization.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PreAction();
|
||||
|
||||
// Compulsory transformation/action post optimization.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PostAction();
|
||||
|
||||
// Optimization transformation/action, optional.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Optimize();
|
||||
|
||||
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
|
||||
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
|
||||
// it ready for execution.
|
||||
// @param Total number of epochs that will be run on this tree
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PrepareDeprecated();
|
||||
|
||||
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
|
||||
// node actions during a tree walk.
|
||||
// @param op - The dataset op to work on
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
|
||||
|
||||
// Return the pointer to the TaskGroup
|
||||
|
|
|
@ -51,7 +51,7 @@ class Edge {
|
|||
// Get the feature of a edge
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
|
||||
|
||||
// Get nodes on the edge
|
||||
|
@ -71,7 +71,7 @@ class Edge {
|
|||
|
||||
// Update feature of edge
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
|
||||
|
||||
protected:
|
||||
|
|
|
@ -47,19 +47,19 @@ class GraphData {
|
|||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get all edges from the graph.
|
||||
// @param EdgeType edge_type - type of edge
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get the node id from the edge.
|
||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
|
@ -68,7 +68,7 @@ class GraphData {
|
|||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
|
@ -77,7 +77,7 @@ class GraphData {
|
|||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
@ -87,7 +87,7 @@ class GraphData {
|
|||
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
|
@ -98,7 +98,7 @@ class GraphData {
|
|||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
|
@ -108,7 +108,7 @@ class GraphData {
|
|||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) = 0;
|
||||
|
||||
|
@ -117,7 +117,7 @@ class GraphData {
|
|||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) = 0;
|
||||
|
||||
|
|
|
@ -57,19 +57,19 @@ class GraphDataClient : public GraphData {
|
|||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get all edges from the graph.
|
||||
// @param EdgeType edge_type - type of edge
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get the node id from the edge.
|
||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get all neighbors of the given nodes.
|
||||
|
@ -78,7 +78,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
|
@ -87,7 +87,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
|
@ -96,7 +96,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
|
@ -107,7 +107,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
@ -117,7 +117,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
|
@ -126,7 +126,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
|
|
|
@ -51,19 +51,19 @@ class GraphDataImpl : public GraphData {
|
|||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get all edges from the graph.
|
||||
// @param EdgeType edge_type - type of edge
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get the node id from the edge.
|
||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get all neighbors of the given nodes.
|
||||
|
@ -72,7 +72,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
|
@ -81,7 +81,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
|
@ -90,7 +90,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
|
@ -101,7 +101,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
@ -111,7 +111,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
|
@ -123,7 +123,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
|
@ -132,7 +132,7 @@ class GraphDataImpl : public GraphData {
|
|||
|
||||
// Get meta information of graph
|
||||
// @param MetaInfo *meta_info - Returned meta information
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetMetaInfo(MetaInfo *meta_info);
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
|
@ -202,14 +202,14 @@ class GraphDataImpl : public GraphData {
|
|||
};
|
||||
|
||||
// Load graph data from mindrecord file
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status LoadNodeAndEdge();
|
||||
|
||||
// Create Tensor By Vector
|
||||
// @param std::vector<std::vector<T>> &data -
|
||||
// @param DataType type -
|
||||
// @param std::shared_ptr<Tensor> *out -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
template <typename T>
|
||||
Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out);
|
||||
|
||||
|
@ -217,32 +217,32 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::vector<std::vector<T>> *data - To be completed vector
|
||||
// @param size_t max_size - The size of the completed vector
|
||||
// @param T default_value - Filled default
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
template <typename T>
|
||||
Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value);
|
||||
|
||||
// Get the default feature of a node
|
||||
// @param FeatureType feature_type -
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
|
||||
|
||||
// Get the default feature of an edge
|
||||
// @param FeatureType feature_type -
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
|
||||
|
||||
// Find node object using node id
|
||||
// @param NodeIdType id -
|
||||
// @param std::shared_ptr<Node> *node - Returned node object
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
|
||||
|
||||
// Find edge object using edge id
|
||||
// @param EdgeIdType id -
|
||||
// @param std::shared_ptr<Edge> *edge - Returned edge object
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge);
|
||||
|
||||
// Negative sampling
|
||||
|
@ -250,7 +250,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
|
||||
// @param int32_t samples_num -
|
||||
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
|
||||
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_samples);
|
||||
|
|
|
@ -43,12 +43,12 @@ class LocalEdge : public Edge {
|
|||
// Get the feature of an edge
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
|
||||
|
||||
// Update feature of edge
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -40,13 +40,13 @@ class LocalNode : public Node {
|
|||
// Get the feature of a node
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
|
||||
|
||||
// Get all the neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
|
||||
bool exclude_itself = false) override;
|
||||
|
||||
|
@ -54,18 +54,18 @@ class LocalNode : public Node {
|
|||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_neighbors) override;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status AddNeighbor(const std::shared_ptr<Node> &node) override;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -49,13 +49,13 @@ class Node {
|
|||
// Get the feature of a node
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
|
||||
|
||||
// Get all the neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
|
||||
bool exclude_itself = false) = 0;
|
||||
|
||||
|
@ -63,18 +63,18 @@ class Node {
|
|||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
|
||||
|
||||
protected:
|
||||
|
|
|
@ -57,6 +57,10 @@ class RepeatNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Getter
|
||||
/// \return Number of cycles to repeat the execution
|
||||
const int32_t Count() const { return repeat_count_; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
|
|
|
@ -55,6 +55,10 @@ class SkipNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Getter
|
||||
/// \return Number of rows to skip
|
||||
const int32_t Count() const { return skip_count_; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
|
|
|
@ -55,6 +55,10 @@ class TakeNode : public DatasetNode {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Getter
|
||||
/// \return Number of rows to output
|
||||
const int32_t Count() const { return take_count_; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
|
|
|
@ -29,7 +29,7 @@ class TensorOpFusionPass : public NodePass {
|
|||
/// \brief Identifies and fuses tensor ops within MapOp
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicates whether the node has been visited
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -135,7 +135,7 @@ class IRTreePass : public IRPass {
|
|||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] modified Indicate if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
  // Default implementation is a no-op: neither root_ir nor *modified is touched, and success is reported.
  return Status::OK();
}
|
||||
};
|
||||
|
||||
|
@ -170,14 +170,14 @@ class IRNodePass : public IRPass {
|
|||
/// a tree traversal. "modified" flag needs to be set to true if node is modified during the pass execution
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[out] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
  // Default implementation is a no-op: neither node nor *modified is touched, and success is reported.
  return Status::OK();
}
|
||||
|
||||
/// \brief Derived classes may implement the VisitAfter function to implement node level tree transformation
|
||||
/// "modified" flag needs to be set to true if node is modified during the pass execution
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[out] modified Indicator if the node was changed at all.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) {
  // Default implementation is a no-op: neither node nor *modified is touched, and success is reported.
  return Status::OK();
}
|
||||
|
||||
// Visit()/VisitAfter() method to be overridden.
|
||||
|
@ -266,7 +266,7 @@ class TreePass : public Pass {
|
|||
/// "modified" flag needs to be set to true if tree is modified during the pass execution.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] modified Indicate if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) {
  // Default implementation is a no-op: neither tree nor *modified is touched, and success is reported.
  return Status::OK();
}
|
||||
};
|
||||
|
||||
|
@ -301,14 +301,14 @@ class NodePass : public Pass {
|
|||
/// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[out] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status PreRunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
  // Default implementation is a no-op: neither node nor *modified is touched, and success is reported.
  return Status::OK();
}
|
||||
|
||||
/// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation
|
||||
/// "modified" flag needs to be set to true if tree is modified during the pass execution
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[out] modified Indicator if the node was changed at all.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
  // Default implementation is a no-op: neither node nor *modified is touched, and success is reported.
  return Status::OK();
}
|
||||
|
||||
// Visit methods to be overridden.
|
||||
|
|
|
@ -41,80 +41,80 @@ class RepeatPass : public NodePass {
|
|||
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being in a cache merge path
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree below this node as being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Hooks up any identified eoe nodes under this repeat.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
|
||||
|
||||
/// \brief CacheOp removes previous leaf ops and replaces them with itself
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Turns off the tracking for operations under merge op
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Saves the lookup op in case it needs to be referenced by a repeat
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Set the epoch count for DeviceQueue
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Special case for GeneratorOp
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
|
||||
/// for use with a controlling repeat above it.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
/// \brief Adds an operator to the eoe operator stack save area
|
||||
/// \param dataset_op - The dataset op to add to the eoe stack
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
/// \brief Pops an operator from the eoe operator stack save area
|
||||
|
@ -127,7 +127,7 @@ class RepeatPass : public NodePass {
|
|||
|
||||
/// \brief Adds an operator to the cached operator stack save area
|
||||
/// \param dataset_op - The dataset op to add to the cached stack
|
||||
/// \return Status - The error code return
|
||||
/// \return Status The status code returned
|
||||
void AddToCachedOpStack(std::shared_ptr<DatasetOp> dataset_op);
|
||||
|
||||
/// \brief Pops an operator from the cached operator stack save area
|
||||
|
|
|
@ -38,123 +38,123 @@ class CacheErrorPass : public NodePass {
|
|||
/// \brief Identifies the subtree below this node as being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if ZipOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if MapOp with non-deterministic TensorOps exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if ConcatOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if TakeOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if SkipOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if BatchOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
/// \brief Returns an error if FilterOp exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the leaf dataset as being mappable
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree above this node as not being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies and blocks repeat-under-cache scenarios
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -48,14 +48,14 @@ class CacheTransformPass : public TreePass {
|
|||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree and assigns the operators that
|
||||
/// will be involved in a cache transformation
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -63,95 +63,95 @@ class CacheTransformPass : public TreePass {
|
|||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Perform leaf node cache transform identifications
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
|
@ -161,12 +161,12 @@ class CacheTransformPass : public TreePass {
|
|||
private:
|
||||
/// \brief Common code for mappable leaf setup.
|
||||
/// \param[in] leaf_op The leaf node performing setup work.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||
|
||||
/// \brief Common code for non-mappable leaf setup.
|
||||
/// \param[in] leaf_op The leaf node performing setup work.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
|
||||
|
||||
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
|
||||
|
@ -191,7 +191,7 @@ class CacheTransformPass : public TreePass {
|
|||
/// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] modified Indicates if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
|
||||
private:
|
||||
|
@ -212,7 +212,7 @@ class CacheTransformPass : public TreePass {
|
|||
/// \param[in] leaf_op The leaf node in the transform
|
||||
/// \param[in] cache_op The cache op in the transform (will get removed)
|
||||
/// \param[in] cache_client The cache client
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
|
||||
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
|
||||
};
|
||||
|
|
|
@ -38,61 +38,61 @@ class CacheValidationPass : public IRNodePass {
|
|||
/// \brief Returns an error if BatchNode exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if ConcatNode exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<ConcatNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if FilterNode exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<FilterNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if SkipNode exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if TakeNode exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if ZipNode exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<ZipNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if MapNode with non-deterministic tensor operations exists under a cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Returns an error if there is a cache over another cache
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies and blocks repeat-under-cache scenarios
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Identifies the subtree above this node as not being cached
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -45,27 +45,27 @@ class EpochCtrlPass : public IRTreePass {
|
|||
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<RootNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Performs finder work for BuildVocabNode that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<BuildVocabNode> node, bool *modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Performs finder work for BuildSentenceVocabNode that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Register the TransferNode for further action.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
|
@ -89,7 +89,7 @@ class EpochCtrlPass : public IRTreePass {
|
|||
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
/// \param[inout] root_ir The root node of the IR tree to operate on.
|
||||
/// \param[inout] modified Indicates if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -46,20 +46,20 @@ class EpochInjectionPass : public TreePass {
|
|||
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Register the DeviceQueueOp for further action.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
|
@ -79,7 +79,7 @@ class EpochInjectionPass : public TreePass {
|
|||
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] modified Indicates if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -17,7 +17,10 @@
|
|||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "minddata/dataset/engine/opt/pre/node_removal_pass.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/repeat_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -47,7 +50,16 @@ Status NodeRemovalPass::RemovalNodes::VisitAfter(std::shared_ptr<DatasetNode> no
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform ShuffleOp removal check.
|
||||
// Perform RepeatNode removal check.
|
||||
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<RepeatNode> node, bool *modified) {
|
||||
*modified = false;
|
||||
if (node->Count() == 1) {
|
||||
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform ShuffleNode removal check.
|
||||
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, bool *modified) {
|
||||
*modified = false;
|
||||
#if 0
|
||||
|
@ -60,6 +72,24 @@ Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<ShuffleNode> node, b
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform SkipNode removal check.
|
||||
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<SkipNode> node, bool *modified) {
|
||||
*modified = false;
|
||||
if (node->Count() == 0) {
|
||||
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Perform TakeNode removal check.
|
||||
Status NodeRemovalPass::RemovalNodes::Visit(std::shared_ptr<TakeNode> node, bool *modified) {
|
||||
*modified = false;
|
||||
if (node->Count() == -1) {
|
||||
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetNode>(node));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// constructor
|
||||
NodeRemovalPass::NodeRemovalPass() {}
|
||||
|
||||
|
|
|
@ -45,21 +45,39 @@ class NodeRemovalPass : public IRTreePass {
|
|||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<DatasetNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform RepeatNode removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<RepeatNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform ShuffleNode removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<ShuffleNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform SkipNode removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<SkipNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Perform TakeNode removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<TakeNode> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
/// \return All the nodes to be removed
|
||||
std::vector<std::shared_ptr<DatasetNode>> nodes_to_remove() { return nodes_to_remove_; }
|
||||
|
@ -79,7 +97,7 @@ class NodeRemovalPass : public IRTreePass {
|
|||
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
/// \param[inout] root_ir The root node of the IR tree to operate on.
|
||||
/// \param[inout] modified Indicates if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -46,20 +46,20 @@ class RemovalPass : public TreePass {
|
|||
/// \brief Identifies the subtree below this node as a cached descendant tree.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Resets the tracking of the cache within the tree
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Perform ShuffleOp removal check
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
|
@ -81,7 +81,7 @@ class RemovalPass : public TreePass {
|
|||
/// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
|
||||
/// \param[inout] tree The tree to operate on.
|
||||
/// \param[inout] modified Indicates if the tree was modified.
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -53,7 +53,7 @@ class ConnectorSize : public Sampling {
|
|||
std::string Name() const override { return kConnectorSizeSamplingName; }
|
||||
|
||||
// Save sampling data to file
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SaveToFile() override;
|
||||
|
||||
Status Init(const std::string &dir_path, const std::string &device_id) override;
|
||||
|
|
|
@ -65,7 +65,7 @@ class ConnectorThroughput : public Sampling {
|
|||
std::string Name() const override { return name_; };
|
||||
|
||||
// Save sampling data to file
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SaveToFile() override;
|
||||
|
||||
Status Init(const std::string &dir_path, const std::string &device_id);
|
||||
|
|
|
@ -32,13 +32,13 @@ class DatasetIteratorTracing : public Tracing {
|
|||
~DatasetIteratorTracing() override = default;
|
||||
|
||||
// Record tracing data
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);
|
||||
|
||||
std::string Name() const override { return kDatasetIteratorTracingName; };
|
||||
|
||||
// Save tracing data to file
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SaveToFile() override;
|
||||
|
||||
Status Init(const std::string &dir_path, const std::string &device_id) override;
|
||||
|
|
|
@ -32,13 +32,13 @@ class DeviceQueueTracing : public Tracing {
|
|||
~DeviceQueueTracing() override = default;
|
||||
|
||||
// Record tracing data
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);
|
||||
|
||||
std::string Name() const override { return kDeviceQueueTracingName; };
|
||||
|
||||
// Save tracing data to file
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SaveToFile() override;
|
||||
|
||||
Status Init(const std::string &dir_path, const std::string &device_id) override;
|
||||
|
|
|
@ -87,19 +87,19 @@ class ProfilingManager {
|
|||
Status Initialize();
|
||||
|
||||
// Save profile data to file
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status SaveProfilingData();
|
||||
|
||||
// Sampling node getter
|
||||
// @param name - The name of the requested node
|
||||
// @param node - Pointer to the shared pointer for the Sampling node
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node);
|
||||
|
||||
// Tracing node getter
|
||||
// @param name - The name of the requested node
|
||||
// @param node - Pointer to the shared pointer for the Tracing node
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node);
|
||||
|
||||
// If profiling is enabled.
|
||||
|
@ -120,12 +120,12 @@ class ProfilingManager {
|
|||
|
||||
// Register profile node to tree
|
||||
// @param node - Profiling node
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status RegisterTracingNode(std::shared_ptr<Tracing> node);
|
||||
|
||||
// Register profile node to tree
|
||||
// @param node - Profiling node
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status RegisterSamplingNode(std::shared_ptr<Sampling> node);
|
||||
|
||||
ExecutionTree *tree_ = nullptr; // ExecutionTree pointer
|
||||
|
|
|
@ -442,7 +442,7 @@ class SchemaObj {
|
|||
Status parse_column(nlohmann::json columns);
|
||||
|
||||
/// \brief Get schema file from JSON file
|
||||
/// \param[in] json_obj Object of JSON parsed.
|
||||
/// \param[in] json_obj parsed JSON object
|
||||
/// \return Status code
|
||||
Status from_json(nlohmann::json json_obj);
|
||||
|
||||
|
|
|
@ -77,7 +77,7 @@ Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format,
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
|
||||
const std::shared_ptr<Tensor> &pad_val);
|
||||
|
||||
|
@ -86,7 +86,7 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
|||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param float pad_val - value to pad with
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, float pad_val);
|
||||
|
||||
|
@ -98,7 +98,7 @@ Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
|
|||
// @param std::vector<dsize_t> cur_ind - recursion helper
|
||||
// @param T pad_val - value to pad tensor with
|
||||
// @param size_t cur_dim - recursion helper
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
|
||||
std::vector<dsize_t> cur_ind, size_t cur_dim = 0);
|
||||
|
||||
|
@ -107,7 +107,7 @@ Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<T
|
|||
// @param std::shared_ptr<Tensor> *dst - return tensor padded
|
||||
// @param std::vector<dsize_t> pad_shape - shape to pad to
|
||||
// @param std::string pad_val - value to pad with
|
||||
// @return - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
|
||||
const std::vector<dsize_t> &pad_shape, const std::string &pad_val);
|
||||
|
||||
|
@ -119,7 +119,7 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
|
|||
// @param std::vector<dsize_t> cur_ind - recursion helper
|
||||
// @param std::string pad_val - value to pad tensor with
|
||||
// @param size_t cur_dim - recursion helper
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
|
||||
const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
|
||||
const std::string &pad_value);
|
||||
|
|
|
@ -36,7 +36,7 @@ class ToFloat16Op : public TensorOp {
|
|||
// Overrides the base class compute function
|
||||
// Calls the ToFloat16 function in ImageUtils, this function takes an input tensor
|
||||
// and transforms its data to float16, the output memory is manipulated to contain the result
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
|
|
@ -58,7 +58,7 @@ class CutOutOp : public TensorOp {
|
|||
// Overrides the base class compute function
|
||||
// Calls the erase function in ImageUtils, this function takes an input tensor
|
||||
// and overwrites some of its data using openCV, the output memory is manipulated to contain the result
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kCutOutOp; }
|
||||
|
|
|
@ -50,7 +50,7 @@ class RandomColorAdjustOp : public TensorOp {
|
|||
// Overrides the base class compute function.
|
||||
// Calls multiple transform functions in ImageUtils, this function takes an input tensor.
|
||||
// and transforms its data using openCV, the output memory is manipulated to contain the result.
|
||||
// @return Status - The error code return.
|
||||
// @return Status The status code returned.
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kRandomColorAdjustOp; }
|
||||
|
|
|
@ -61,7 +61,7 @@ class RandomRotationOp : public TensorOp {
|
|||
// Overrides the base class compute function
|
||||
// Calls the rotate function in ImageUtils, this function takes an input tensor
|
||||
// and transforms its data using openCV, the output memory is manipulated to contain the result
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class UniformAugOp : public TensorOp {
|
|||
void Print(std::ostream &out) const override { out << Name() << ":: number of ops " << num_ops_; }
|
||||
|
||||
// Overrides the base class compute function
|
||||
// @return Status - The error code return
|
||||
// @return Status The status code returned
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
std::string Name() const override { return kUniformAugOp; }
|
||||
|
|
|
@ -53,7 +53,7 @@ class DataHelper {
|
|||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional input for output file path, will write to input file if not specified
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value,
|
||||
const std::string &out_file = "");
|
||||
|
||||
|
@ -62,7 +62,7 @@ class DataHelper {
|
|||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
template <typename T>
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value,
|
||||
const std::string &out_file = "") {
|
||||
|
@ -99,7 +99,7 @@ class DataHelper {
|
|||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
template <typename T>
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const T &value,
|
||||
const std::string &out_file = "") {
|
||||
|
@ -134,7 +134,7 @@ class DataHelper {
|
|||
/// \brief Template function to write tensor to file
|
||||
/// \param[in] in_file File to write to
|
||||
/// \param[in] data Array of type T values
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
template <typename T>
|
||||
Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) {
|
||||
try {
|
||||
|
@ -157,7 +157,7 @@ class DataHelper {
|
|||
/// \param[in] in_file File name to write to
|
||||
/// \param[in] data Pointer to data
|
||||
/// \param[in] length Length of values to write from pointer
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
template <typename T>
|
||||
Status WriteBinFile(const std::string &in_file, T *data, size_t length) {
|
||||
try {
|
||||
|
@ -188,7 +188,7 @@ class DataHelper {
|
|||
/// \note This function will return okay even if key not found
|
||||
/// \param[in] in_file Json file to remove key from
|
||||
/// \param[in] key The key to remove
|
||||
/// \return Status The error code return
|
||||
/// \return Status The status code returned
|
||||
Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = "");
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
|
|
|
@ -669,8 +669,6 @@ class Dataset:
|
|||
>>> repeat_and_shuffle = data.repeat(50)
|
||||
>>> repeat_and_shuffle = repeat_and_shuffle.shuffle(10)
|
||||
"""
|
||||
if count == 1:
|
||||
return self
|
||||
return RepeatDataset(self, count)
|
||||
|
||||
@check_skip
|
||||
|
@ -717,8 +715,6 @@ class Dataset:
|
|||
>>> # Create a dataset where the dataset includes 50 elements.
|
||||
>>> data = data.take(50)
|
||||
"""
|
||||
if count == -1:
|
||||
return self
|
||||
return TakeDataset(self, count)
|
||||
|
||||
def _get_absolute_split_sizes(self, sizes):
|
||||
|
|
|
@ -1311,6 +1311,51 @@ TEST_F(MindDataTestPipeline, TestSkipDataset) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
// Builds a pipeline containing Skip(0), Take(-1) and Repeat(1) on top of an ImageFolder dataset
// sampled with RandomSampler(false, 6), and checks that iterating it still yields 6 rows.
TEST_F(MindDataTestPipeline, TestSkipTakeRepeat) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipTakeRepeat.";

  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 6));
  // Abort the test (instead of crashing on a null dereference) if construction failed.
  ASSERT_NE(ds, nullptr);

  // Create a Skip operation on ds
  int32_t count = 0;
  ds = ds->Skip(count);
  ASSERT_NE(ds, nullptr);

  // Create a Project operation on ds
  std::vector<std::string> column_project = {"image"};
  ds = ds->Project(column_project);
  ASSERT_NE(ds, nullptr);

  // Add a Take(-1)
  ds = ds->Take(-1);
  ASSERT_NE(ds, nullptr);

  // Add a Repeat(1)
  ds = ds->Repeat(1);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // iterate over the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // An empty row marks the end of the data.
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image->shape();
    iter->GetNextRow(&row);
  }
  MS_LOG(INFO) << "Number of rows: " << i;

  // Expect 6 rows
  EXPECT_EQ(i, 6);

  // Manually terminate the pipeline
  iter->Stop();
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestSkipGetDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSkipGetDatasetSize.";
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
|
@ -163,7 +163,7 @@ def test_take_08():
|
|||
|
||||
def test_take_09():
|
||||
"""
|
||||
Test take: repeat count is -1, and read the whole dataset, take after repeat
|
||||
Test take: take count is -1, and read the whole dataset, take after repeat
|
||||
"""
|
||||
logger.info("test_take_09")
|
||||
data1 = ds.GeneratorDataset(generator, ["data"])
|
||||
|
@ -180,7 +180,7 @@ def test_take_09():
|
|||
|
||||
def test_take_10():
|
||||
"""
|
||||
Test take: repeat count is -1, and read the whole dataset, take before repeat
|
||||
Test take: take count is -1, and read the whole dataset, take before repeat
|
||||
"""
|
||||
logger.info("test_take_10")
|
||||
data1 = ds.GeneratorDataset(generator, ["data"])
|
||||
|
@ -341,6 +341,18 @@ def test_take_18():
|
|||
assert sum([1 for _ in data1]) == 2
|
||||
|
||||
|
||||
def test_take_19():
    """
    Test take: take count is 0, which is invalid and is expected to raise ValueError
    """
    logger.info("test_take_19")
    data1 = ds.GeneratorDataset(generator, ["data"])
    data1 = data1.batch(2)

    # Keep the pipeline setup outside the raises block so that only take(0)
    # can satisfy the expected ValueError.
    with pytest.raises(ValueError) as info:
        data1.take(0)
    assert "positive integer" in str(info.value)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_take_01()
|
||||
test_take_02()
|
||||
|
@ -360,4 +372,5 @@ if __name__ == '__main__':
|
|||
test_take_16()
|
||||
test_take_17()
|
||||
test_take_18()
|
||||
test_take_19()
|
||||
logger.info('== test take operation finished ==')
|
||||
|
|
Loading…
Reference in New Issue