From 31085958852013054417c2752cae96c36f97e2d1 Mon Sep 17 00:00:00 2001 From: shenwei41 Date: Wed, 21 Oct 2020 18:12:38 +0800 Subject: [PATCH] Add SyncWaitNode class --- .../engine/ir/datasetops/CMakeLists.txt | 1 + .../engine/ir/datasetops/sync_wait_node.cc | 63 +++++++++++++++++++ .../engine/ir/datasetops/sync_wait_node.h | 58 +++++++++++++++++ mindspore/lite/minddata/CMakeLists.txt | 1 + 4 files changed, 123 insertions(+) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt index 446b2195ed5..3f50c8d5d65 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt @@ -13,6 +13,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES repeat_node.cc shuffle_node.cc skip_node.cc + sync_wait_node.cc take_node.cc zip_node.cc ) diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc new file mode 100644 index 00000000000..c4c83903dd4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/barrier_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace api { +// Constructor for SyncWaitNode +SyncWaitNode::SyncWaitNode(std::shared_ptr child, const std::string &condition_name, int32_t num_batch, + py::function callback) + : condition_name_(condition_name), num_batch_(num_batch), callback_(callback) { + this->children.push_back(child); +} + +// Function to build the BarrierOp +std::vector> SyncWaitNode::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(num_batch_, connector_que_size_, condition_name_, callback_)); + return node_ops; +} + +// Function to validate the parameters for SyncWaitNode +Status SyncWaitNode::ValidateParams() { + if (num_batch_ <= 0) { + std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + if (condition_name_.empty()) { + std::string err_msg = "SyncWaitNode: condition_name must not be empty."; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + return Status::OK(); +} +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h new file mode 100644 index 00000000000..3cd033b49ca --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/include/datasets.h" + +namespace mindspore { +namespace dataset { + +namespace api { + +/// \class SyncWaitNode +/// \brief A Dataset derived class to represent SyncWaitNode dataset +class SyncWaitNode : public Dataset { + public: + /// \brief Constructor + explicit SyncWaitNode(std::shared_ptr child, const std::string &condition_name, int32_t num_batch, + py::function callback); + + /// \brief Destructor + ~SyncWaitNode() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return The list of shared pointers to the newly created DatasetOps + std::vector> Build() override; + + /// \brief Parameters validation + /// \return Status Status::OK() if all the parameters are valid + Status ValidateParams() override; + + private: + std::string condition_name_; + int32_t num_batch_; + py::function callback_; +}; +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_ diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt index d6dc3e5cce2..a6168113678 100644 --- a/mindspore/lite/minddata/CMakeLists.txt +++ b/mindspore/lite/minddata/CMakeLists.txt @@ -140,6 +140,7 @@ if (BUILD_MINDDATA STREQUAL "full") list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES "${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc" "${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc" + "${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc" ) list(REMOVE_ITEM MINDDATA_KERNELS_DATA_SRC_FILES "${MINDDATA_DIR}/kernels/data/unique_op.cc"