Add SyncWaitNode class

This commit is contained in:
shenwei41 2020-10-21 18:12:38 +08:00
parent a6075cc73b
commit 3108595885
4 changed files with 123 additions and 0 deletions

View File

@ -13,6 +13,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
repeat_node.cc
shuffle_node.cc
skip_node.cc
sync_wait_node.cc
take_node.cc
zip_node.cc
)

View File

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/barrier_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for SyncWaitNode
SyncWaitNode::SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch,
py::function callback)
: condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
this->children.push_back(child);
}
// Function to build the BarrierOp
std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_));
return node_ops;
}
// Function to validate the parameters for SyncWaitNode
Status SyncWaitNode::ValidateParams() {
if (num_batch_ <= 0) {
std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (condition_name_.empty()) {
std::string err_msg = "SyncWaitNode: condition_name must not be empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
/// \class SyncWaitNode
/// \brief A Dataset derived class to represent SyncWaitNode dataset
class SyncWaitNode : public Dataset {
public:
/// \brief Constructor
explicit SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch,
py::function callback);
/// \brief Destructor
~SyncWaitNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::string condition_name_;
int32_t num_batch_;
py::function callback_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_

View File

@ -140,6 +140,7 @@ if (BUILD_MINDDATA STREQUAL "full")
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES
"${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc"
"${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc"
)
list(REMOVE_ITEM MINDDATA_KERNELS_DATA_SRC_FILES
"${MINDDATA_DIR}/kernels/data/unique_op.cc"