forked from mindspore-Ecosystem/mindspore
!7593 C++ api add DeviceQueueOp
Merge pull request !7593 from xiaotianci/device_op
This commit is contained in:
commit
95fe324798
|
@ -62,6 +62,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -72,6 +73,7 @@
|
|||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/services.h"
|
||||
|
||||
// IR leaf nodes
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
|
||||
|
@ -125,6 +127,56 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
|
|||
return iter;
|
||||
}
|
||||
|
||||
// Function to return a transferred Node that transfers data through a device.
|
||||
bool Dataset::DeviceQueue(bool send_epoch_end) {
|
||||
Status rc;
|
||||
|
||||
// Build and launch tree
|
||||
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
|
||||
rc = runtime_context->Init();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get a uuid for queue name
|
||||
std::string queue_name = Services::GetUniqueID();
|
||||
|
||||
// TODO(CRC):
|
||||
// Get device type from ms context
|
||||
std::string device_type = "CPU";
|
||||
|
||||
// Get device ID from children
|
||||
int32_t device_id = 0;
|
||||
rc = TransferNode::get_distribution(shared_from_this(), &device_id);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add TransferNode IR on top of dataset d
|
||||
auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end);
|
||||
|
||||
// Get ToDevice consumer
|
||||
auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1);
|
||||
ToDevice *consumer_ = consumer.get();
|
||||
rc = consumer->Init(ds);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
|
||||
return false;
|
||||
}
|
||||
runtime_context->AssignConsumer(std::move(consumer));
|
||||
|
||||
// Send data to device
|
||||
rc = consumer_->Send();
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Function to create the saver, which will build and launch the execution tree and save data
|
||||
bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) {
|
||||
|
@ -931,6 +983,7 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
|
|||
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
|
||||
return cache->ValidateParams() ? cache : nullptr;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace api
|
||||
|
|
|
@ -74,13 +74,31 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
|
|||
|
||||
// ToDevice
|
||||
Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
|
||||
// TODO(CRC):
|
||||
// Get device ID from children look at get_distribution in python
|
||||
// Add DeviceQue IR on top of dataset d
|
||||
|
||||
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
|
||||
}
|
||||
|
||||
Status ToDevice::Send() {
|
||||
std::unique_ptr<DataBuffer> db;
|
||||
RETURN_IF_NOT_OK(tree_adapter_->Launch());
|
||||
RETURN_IF_NOT_OK(tree_adapter_->root()->GetNextBuffer(&db));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ToDevice::Continue() {
|
||||
// tree_.root() must be DeviceQueueOp
|
||||
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp");
|
||||
op->ContinueSend();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ToDevice::Stop() {
|
||||
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
|
||||
op->StopSend();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// SaveToDisk
|
||||
Status SaveToDisk::ValidateParams() {
|
||||
|
|
|
@ -126,23 +126,27 @@ class SaveToDisk : public TreeConsumer {
|
|||
/// Consumer that iterates over the dataset and send it to a device
|
||||
class ToDevice : public TreeConsumer {
|
||||
public:
|
||||
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs)
|
||||
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1)
|
||||
: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
|
||||
|
||||
Status Init(std::shared_ptr<api::Dataset> d) override;
|
||||
|
||||
Status Send() {
|
||||
// TODO(CRC): launch the tree
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status Stop() {
|
||||
// TODO(CRC): Get root + call StopSend
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status Continue() {
|
||||
// TODO(CRC): Get root + call StopSend
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
/// Send the data to device
|
||||
/// \return Status error code
|
||||
Status Send();
|
||||
|
||||
/// Stop to send data to device
|
||||
/// \return Status error code
|
||||
Status Stop();
|
||||
|
||||
/// Continue to send data to device
|
||||
/// \return Status error code
|
||||
Status Continue();
|
||||
|
||||
protected:
|
||||
/// Method to return the name of the consumer
|
||||
/// \return string
|
||||
std::string Name() override { return "ToDevice"; }
|
||||
|
||||
private:
|
||||
std::string device_type_;
|
||||
|
|
|
@ -15,6 +15,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
|
|||
skip_node.cc
|
||||
sync_wait_node.cc
|
||||
take_node.cc
|
||||
transfer_node.cc
|
||||
zip_node.cc
|
||||
)
|
||||
|
||||
|
|
|
@ -68,6 +68,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status AlbumNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,10 @@ class AlbumNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string schema_path_;
|
||||
|
|
|
@ -67,6 +67,13 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status CelebANode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,6 +46,10 @@ class CelebANode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -66,6 +66,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status Cifar100Node::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,10 @@ class Cifar100Node : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -64,6 +64,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status Cifar10Node::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,10 @@ class Cifar10Node : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -213,6 +213,13 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status CLUENode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = shard_id_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,6 +45,10 @@ class CLUENode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
/// \brief Split string based on a character delimiter
|
||||
/// \return A string vector
|
||||
|
|
|
@ -117,6 +117,14 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
|
|||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status CocoNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,6 +43,10 @@ class CocoNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string annotation_file_;
|
||||
|
|
|
@ -122,6 +122,14 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
|
|||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status CSVNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = shard_id_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -66,6 +66,10 @@ class CSVNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
char field_delim_;
|
||||
|
|
|
@ -70,6 +70,14 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
|
|||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status ImageFolderNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -51,6 +51,10 @@ class ImageFolderNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
bool decode_;
|
||||
|
|
|
@ -85,6 +85,14 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
|
|||
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status ManifestNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,10 @@ class ManifestNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_file_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -160,6 +160,13 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status MindDataNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -48,6 +48,10 @@ class MindDataNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Build sampler chain for minddata dataset
|
||||
/// \return Status Status::OK() if input sampler is valid
|
||||
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
|
||||
|
|
|
@ -60,6 +60,13 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status MnistNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -44,6 +44,10 @@ class MnistNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
std::string usage_;
|
||||
|
|
|
@ -99,6 +99,13 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status RandomNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -65,6 +65,10 @@ class RandomNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
/// \brief A quick inline for producing a random number between (and including) min/max
|
||||
/// \param[in] min minimum number that can be generated.
|
||||
|
|
|
@ -95,6 +95,13 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status TextFileNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = shard_id_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,6 +45,10 @@ class TextFileNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
int32_t num_samples_;
|
||||
|
|
|
@ -80,6 +80,13 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status TFRecordNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = shard_id_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -72,6 +72,10 @@ class TFRecordNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string
|
||||
|
|
|
@ -112,6 +112,13 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status VOCNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,6 +45,10 @@ class VOCNode : public Dataset {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
private:
|
||||
const std::string kColumnImage = "image";
|
||||
const std::string kColumnTarget = "target";
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Constructor for TransferNode
|
||||
TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
|
||||
const std::string &device_type, bool send_epoch_end)
|
||||
: queue_name_(queue_name),
|
||||
device_id_(device_id),
|
||||
device_type_(device_type),
|
||||
prefetch_size_(16),
|
||||
send_epoch_end_(send_epoch_end),
|
||||
total_batch_(0) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
// Validator for TransferNode
|
||||
Status TransferNode::ValidateParams() {
|
||||
// Check if device_type_ is in {"CPU", "GPU", "Ascend"}
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Function to build TransferNode
|
||||
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Convert device_type_ from string to DeviceType
|
||||
DeviceQueueOp::DeviceType type;
|
||||
if (device_type_ == "CPU") {
|
||||
type = DeviceQueueOp::DeviceType::CPU;
|
||||
} else if (device_type_ == "GPU") {
|
||||
type = DeviceQueueOp::DeviceType::GPU;
|
||||
} else if (device_type_ == "Ascend") {
|
||||
type = DeviceQueueOp::DeviceType::Ascend;
|
||||
}
|
||||
|
||||
node_ops.push_back(
|
||||
std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_));
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Function to get the device_id
|
||||
Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) {
|
||||
// Get device id according to the type of dataset
|
||||
Status rc = ds->GetShardId(device_id);
|
||||
if (rc != Status::OK()) {
|
||||
// Get device id from the child node
|
||||
if (ds->children.size()) {
|
||||
ds = ds->children[0];
|
||||
return TransferNode::get_distribution(ds, device_id);
|
||||
} else {
|
||||
std::string err_msg = "Unknown dataset type.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
class TransferNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
|
||||
const std::string &device_type, bool send_epoch_end);
|
||||
|
||||
/// \brief Destructor
|
||||
~TransferNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return shared pointer to the list of newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id);
|
||||
|
||||
private:
|
||||
std::string queue_name_;
|
||||
int32_t device_id_;
|
||||
std::string device_type_;
|
||||
int32_t prefetch_size_;
|
||||
bool send_epoch_end_;
|
||||
int32_t total_batch_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
|
|
@ -57,6 +57,10 @@ class TreeAdapter {
|
|||
// to be able to launch a thread. BuildAndPrepare needs to be called before this function
|
||||
TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
|
||||
|
||||
std::shared_ptr<DatasetOp> root() { return tree_->root(); }
|
||||
|
||||
Status Launch() const { return tree_->Launch(); }
|
||||
|
||||
private:
|
||||
// This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In
|
||||
// such case, the first node is returned. Op is added as child when the current function returns.
|
||||
|
|
|
@ -96,6 +96,7 @@ class RepeatNode;
|
|||
class ShuffleNode;
|
||||
class SkipNode;
|
||||
class TakeNode;
|
||||
class TransferNode;
|
||||
class ZipNode;
|
||||
|
||||
#define RETURN_EMPTY_IF_ERROR(_s) \
|
||||
|
@ -559,6 +560,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
public:
|
||||
// need friend class so they can access the children_ field
|
||||
friend class Iterator;
|
||||
friend class TransferNode;
|
||||
friend class mindspore::dataset::TreeAdapter;
|
||||
|
||||
/// \brief Constructor
|
||||
|
@ -579,6 +581,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Status Status::OK() if all the parameters are valid
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
/// \brief Pure virtual function for derived class to get the shard id of specific node
|
||||
/// \return Status Status::OK() if get shard id successfully
|
||||
virtual Status GetShardId(int32_t *shard_id) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
|
||||
/// \brief Gets the dataset size
|
||||
/// \return status code
|
||||
int64_t GetDatasetSize();
|
||||
|
@ -617,6 +625,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Shared pointer to the Iterator
|
||||
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
|
||||
|
||||
/// \brief Function to transfer data through a device.
|
||||
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
|
||||
/// of data transmission per time is 256M.
|
||||
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=True).
|
||||
/// \return Returns true if no error encountered else false.
|
||||
bool DeviceQueue(bool send_epoch_end = true);
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
|
||||
/// \note Usage restrictions:
|
||||
|
|
|
@ -17,8 +17,9 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
|
@ -48,6 +49,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
|
|||
/// \return Shared pointers to the newly created Sampler
|
||||
virtual std::shared_ptr<Sampler> Build() = 0;
|
||||
|
||||
/// \brief Function for derived class to get the shard id of sampler
|
||||
/// \return The shard id of the derived sampler
|
||||
virtual int64_t ShardId() { return 0; }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
|
||||
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
|
||||
|
@ -134,6 +139,10 @@ class DistributedSamplerObj : public SamplerObj {
|
|||
|
||||
bool ValidateParams() override;
|
||||
|
||||
/// \brief Function to get the shard id of sampler
|
||||
/// \return The shard id of sampler
|
||||
int64_t ShardId() override { return shard_id_; }
|
||||
|
||||
private:
|
||||
int64_t num_shards_;
|
||||
int64_t shard_id_;
|
||||
|
|
Loading…
Reference in New Issue