add callback to map's c api

fix test err

fix ci round 1

fix ci round 2

add col_order to batch_node

minor fix
This commit is contained in:
Zirui Wu 2020-11-06 13:47:18 -05:00
parent a9a9c45662
commit 7538f82da9
13 changed files with 94 additions and 39 deletions

View File

@ -499,9 +499,10 @@ FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<Tenso
MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns, std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache) { const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache,
auto ds = std::vector<std::shared_ptr<DSCallback>> callbacks) {
std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns, cache); auto ds = std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns,
cache, callbacks);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }

View File

@ -44,7 +44,12 @@ bool Iterator::GetNextRow(TensorVec *row) {
} }
// Shut down the data pipeline. // Shut down the data pipeline.
void Iterator::Stop() { runtime_context_->Terminate(); } void Iterator::Stop() {
Status rc = runtime_context_->Terminate();
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
}
}
// Function to build and launch the execution tree. // Function to build and launch the execution tree.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {

View File

@ -385,6 +385,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status /// \return Status
virtual Status WaitForWorkers() { return Status::OK(); } virtual Status WaitForWorkers() { return Status::OK(); }
/// \brief Add callback to DatasetOp, only MapOp supports Callback at the moment
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); }
protected: protected:
/// \brief Removes a parent operator from this operator /// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function /// \notes External callers do not have access to this function

View File

@ -13,9 +13,9 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <iostream>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/config_manager.h"
@ -26,8 +26,6 @@
#include "minddata/dataset/engine/data_buffer.h" #include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/kernels/tensor_op.h" #include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/task_manager.h" #include "minddata/dataset/util/task_manager.h"
@ -60,7 +58,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
RETURN_IF_NOT_OK(sanityCheck()); RETURN_IF_NOT_OK(sanityCheck());
*ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_), *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_); std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
(*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_)); (*ptr)->AddCallbacks(std::move(builder_callbacks_));
return Status::OK(); return Status::OK();
} }

View File

@ -31,13 +31,15 @@ namespace dataset {
// constructor #1, called by Pybind // constructor #1, called by Pybind
BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names, const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
py::function batch_size_func, py::function batch_map_func, const std::vector<std::string> &col_order, py::function batch_size_func,
py::function batch_map_func,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map) std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
: batch_size_(batch_size), : batch_size_(batch_size),
drop_remainder_(drop_remainder), drop_remainder_(drop_remainder),
pad_(pad), pad_(pad),
in_col_names_(in_col_names), in_col_names_(in_col_names),
out_col_names_(out_col_names), out_col_names_(out_col_names),
col_order_(col_order),
batch_size_func_(batch_size_func), batch_size_func_(batch_size_func),
batch_map_func_(batch_map_func), batch_map_func_(batch_map_func),
pad_map_(pad_map) { pad_map_(pad_map) {
@ -83,8 +85,8 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, in_col_names_, out_col_names_, batch_size_func_, batch_map_func_,
pad_map_)); pad_map_));
// need to insert a project when per_batch_func changes the number of columns // need to insert a project when per_batch_func changes the number of columns
if (!out_col_names_.empty()) { if (!col_order_.empty()) {
auto project_op = std::make_shared<ProjectOp>(out_col_names_); auto project_op = std::make_shared<ProjectOp>(col_order_);
node_ops.push_back(project_op); node_ops.push_back(project_op);
} }
#else #else

View File

@ -34,7 +34,7 @@ class BatchNode : public DatasetNode {
/// \brief Constructor #1, for Python API to create a BatchNode /// \brief Constructor #1, for Python API to create a BatchNode
BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names, const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
py::function batch_size_func, py::function batch_map_func, const std::vector<std::string> &col_order, py::function batch_size_func, py::function batch_map_func,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map); std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
#endif #endif
@ -58,6 +58,7 @@ class BatchNode : public DatasetNode {
bool pad_; bool pad_;
std::vector<std::string> in_col_names_; std::vector<std::string> in_col_names_;
std::vector<std::string> out_col_names_; std::vector<std::string> out_col_names_;
std::vector<std::string> col_order_;
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
py::function batch_size_func_; py::function batch_size_func_;
py::function batch_map_func_; py::function batch_map_func_;

View File

@ -29,12 +29,14 @@ namespace dataset {
MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns, std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache) const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache,
std::vector<std::shared_ptr<DSCallback>> callbacks)
: operations_(operations), : operations_(operations),
input_columns_(input_columns), input_columns_(input_columns),
output_columns_(output_columns), output_columns_(output_columns),
project_columns_(project_columns), project_columns_(project_columns),
DatasetNode(std::move(cache)) { DatasetNode(std::move(cache)),
callbacks_(callbacks) {
this->children.push_back(child); this->children.push_back(child);
} }
@ -53,6 +55,11 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
// This parameter will be removed with next rebase // This parameter will be removed with next rebase
std::vector<std::string> col_orders; std::vector<std::string> col_orders;
auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
if (!callbacks_.empty()) {
map_op->AddCallbacks(callbacks_);
}
if (!project_columns_.empty()) { if (!project_columns_.empty()) {
auto project_op = std::make_shared<ProjectOp>(project_columns_); auto project_op = std::make_shared<ProjectOp>(project_columns_);
node_ops.push_back(project_op); node_ops.push_back(project_op);

View File

@ -31,7 +31,8 @@ class MapNode : public DatasetNode {
/// \brief Constructor /// \brief Constructor
MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {});
/// \brief Destructor /// \brief Destructor
~MapNode() = default; ~MapNode() = default;
@ -49,6 +50,7 @@ class MapNode : public DatasetNode {
std::vector<std::string> input_columns_; std::vector<std::string> input_columns_;
std::vector<std::string> output_columns_; std::vector<std::string> output_columns_;
std::vector<std::string> project_columns_; std::vector<std::string> project_columns_;
std::vector<std::shared_ptr<DSCallback>> callbacks_;
}; };
} // namespace dataset } // namespace dataset

View File

@ -149,7 +149,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// We finish the walk of this RepeatOp's descendent nodes. // We finish the walk of this RepeatOp's descendent nodes.
// The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n.
// But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode,
// so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. // so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp.
num_repeats_ /= node->num_repeats(); num_repeats_ /= node->num_repeats();
return Status::OK(); return Status::OK();
} }

View File

@ -120,7 +120,7 @@ Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
} }
// Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf
// as save it for use by cache op in ascendant tree. // as save it for use by cache op in ascendant tree.
if (is_caching_) { if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
@ -261,7 +261,8 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
// Then, execute the transform for each pair // Then, execute the transform for each pair
for (auto cache_pair : cache_pass.cache_pairs()) { for (auto cache_pair : cache_pass.cache_pairs()) {
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); RETURN_IF_NOT_OK(
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()));
} }
MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
return Status::OK(); return Status::OK();

View File

@ -60,95 +60,95 @@ class CacheTransformPass : public TreePass {
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif #endif
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
#endif #endif
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
#ifndef ENABLE_ANDROID #ifndef ENABLE_ANDROID
/// \brief Perform leaf node cache tranform identifications /// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited /// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all /// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return /// \return Status The error code return

View File

@ -276,9 +276,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
std::vector<std::string> input_columns = {}, std::vector<std::string> input_columns = {},
std::vector<std::string> output_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &project_columns = {}, const std::vector<std::string> &project_columns = {},
const std::shared_ptr<DatasetCache> &cache = nullptr) { const std::shared_ptr<DatasetCache> &cache = nullptr,
std::vector<std::shared_ptr<DSCallback>> callbacks = {}) {
return std::make_shared<MapDataset>(shared_from_this(), operations, input_columns, output_columns, project_columns, return std::make_shared<MapDataset>(shared_from_this(), operations, input_columns, output_columns, project_columns,
cache); cache, callbacks);
} }
/// \brief Function to create a Project Dataset /// \brief Function to create a Project Dataset
@ -443,7 +444,8 @@ class MapDataset : public Dataset {
public: public:
MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns, std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache); const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache,
std::vector<std::shared_ptr<DSCallback>> callbacks);
}; };
class ProjectDataset : public Dataset { class ProjectDataset : public Dataset {

View File

@ -21,6 +21,8 @@
#include "minddata/dataset/callback/ds_callback.h" #include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/core/client.h" #include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h" #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/kernels/data/no_op.h" #include "minddata/dataset/kernels/data/no_op.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -149,7 +151,7 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
schema->AddColumn(col); ASSERT_OK(schema->AddColumn(col));
std::shared_ptr<RandomDataOp> leaf; std::shared_ptr<RandomDataOp> leaf;
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf); rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
@ -196,7 +198,7 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
schema->AddColumn(col); ASSERT_OK(schema->AddColumn(col));
std::shared_ptr<RandomDataOp> leaf; std::shared_ptr<RandomDataOp> leaf;
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
@ -253,7 +255,7 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape shape({}); // empty shape is a 1-value scalar Tensor TensorShape shape({}); // empty shape is a 1-value scalar Tensor
ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
schema->AddColumn(col); ASSERT_OK(schema->AddColumn(col));
std::shared_ptr<RandomDataOp> leaf; std::shared_ptr<RandomDataOp> leaf;
rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
@ -296,3 +298,34 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
} }
TEST_F(MindDataTestCallback, TestCAPICallback) {
MS_LOG(INFO) << "Doing: MindDataTestCallback-TestCAPICallback";
// config callback
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
std::shared_ptr<DSCallback> cb1 = tst_cb;
// config leaf_op, use random_data to avoid I/O
std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
ASSERT_TRUE(schema->add_column("label", "uint32", {}));
std::shared_ptr<Dataset> ds = RandomData(44, schema);
ds = ds->Map({transforms::TypeCast("uint64")}, {"label"}, {}, {}, nullptr, {cb1});
ds = ds->Repeat(2);
TreeAdapter tree_adapter;
// using tree_adapter to set num_epoch = 1
ASSERT_OK(tree_adapter.Compile(ds->IRNode(), 1));
TensorRow row;
ASSERT_OK(tree_adapter.GetNext(&row));
while (!row.empty()) {
ASSERT_OK(tree_adapter.GetNext(&row));
}
std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
// doing resize to make sure no unexpected epoch_end or extra epoch_begin is called
size_t len = 7;
EXPECT_EQ(tst_cb->all_names(len), callback_names);
EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
}