forked from mindspore-Ecosystem/mindspore
Add SyncWaitNode class
This commit is contained in:
parent
a6075cc73b
commit
3108595885
|
@ -13,6 +13,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
|
|||
repeat_node.cc
|
||||
shuffle_node.cc
|
||||
skip_node.cc
|
||||
sync_wait_node.cc
|
||||
take_node.cc
|
||||
zip_node.cc
|
||||
)
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/barrier_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
// Constructor for SyncWaitNode
|
||||
SyncWaitNode::SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch,
|
||||
py::function callback)
|
||||
: condition_name_(condition_name), num_batch_(num_batch), callback_(callback) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Function to build the BarrierOp
|
||||
std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to validate the parameters for SyncWaitNode
|
||||
Status SyncWaitNode::ValidateParams() {
|
||||
if (num_batch_ <= 0) {
|
||||
std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (condition_name_.empty()) {
|
||||
std::string err_msg = "SyncWaitNode: condition_name must not be empty.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
/// \class SyncWaitNode
|
||||
/// \brief A Dataset derived class to represent SyncWaitNode dataset
|
||||
class SyncWaitNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SyncWaitNode(std::shared_ptr<Dataset> child, const std::string &condition_name, int32_t num_batch,
|
||||
py::function callback);
|
||||
|
||||
/// \brief Destructor
|
||||
~SyncWaitNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string condition_name_;
|
||||
int32_t num_batch_;
|
||||
py::function callback_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SYNC_WAIT_NODE_H_
|
|
@ -140,6 +140,7 @@ if (BUILD_MINDDATA STREQUAL "full")
|
|||
list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SRC_FILES
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/bucket_batch_by_length_node.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/build_vocab_node.cc"
|
||||
"${MINDDATA_DIR}/engine/ir/datasetops/sync_wait_node.cc"
|
||||
)
|
||||
list(REMOVE_ITEM MINDDATA_KERNELS_DATA_SRC_FILES
|
||||
"${MINDDATA_DIR}/kernels/data/unique_op.cc"
|
||||
|
|
Loading…
Reference in New Issue