!7593 C++ api add DeviceQueueOp

Merge pull request !7593 from xiaotianci/device_op
mindspore-ci-bot 2020-10-24 17:23:40 +08:00 committed by Gitee
commit 95fe324798
39 changed files with 443 additions and 18 deletions

View File

@ -62,6 +62,7 @@
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#ifndef ENABLE_ANDROID
@ -72,6 +73,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
@ -125,6 +127,56 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
return iter;
}
// Function to add a TransferNode on top of the dataset and send the data through a device queue.
bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc;
// Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
return false;
}
// Get a uuid for queue name
std::string queue_name = Services::GetUniqueID();
// TODO(CRC):
// Get device type from ms context
std::string device_type = "CPU";
// Get device ID from children
int32_t device_id = 0;
rc = TransferNode::get_distribution(shared_from_this(), &device_id);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc;
return false;
}
// Add TransferNode IR on top of dataset d
auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end);
// Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1);
ToDevice *consumer_ = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
return false;
}
runtime_context->AssignConsumer(std::move(consumer));
// Send data to device
rc = consumer_->Send();
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
return false;
}
return true;
}
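For context, a minimal caller-side sketch of the new entry point (the ImageFolder and RandomSampler factories, the path, and the repeat count below are illustrative, not part of this change):

    std::shared_ptr<Dataset> ds = ImageFolder("/path/to/train", /*decode=*/true, RandomSampler(false, 10));
    ds = ds->Repeat(2);
    // DeviceQueue() builds the tree with a TransferNode on top and starts sending;
    // it logs and returns false on failure instead of propagating a Status.
    if (!ds->DeviceQueue(/*send_epoch_end=*/true)) {
      MS_LOG(ERROR) << "DeviceQueue failed to send data to the device.";
    }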
#ifndef ENABLE_ANDROID
// Function to create the saver, which will build and launch the execution tree and save data
bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) {
@ -931,6 +983,7 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
return cache->ValidateParams() ? cache : nullptr;
}
#endif
} // namespace api

View File

@ -74,13 +74,31 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
// ToDevice
Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
// TODO(CRC):
// Get device ID from children; see get_distribution in Python
// Add DeviceQueue IR on top of dataset d
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}
Status ToDevice::Send() {
std::unique_ptr<DataBuffer> db;
RETURN_IF_NOT_OK(tree_adapter_->Launch());
RETURN_IF_NOT_OK(tree_adapter_->root()->GetNextBuffer(&db));
return Status::OK();
}
Status ToDevice::Continue() {
// tree_adapter_->root() must be a DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp");
op->ContinueSend();
return Status::OK();
}
Status ToDevice::Stop() {
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopSend();
return Status::OK();
}
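Taken together, the consumer now exposes flow control over the root DeviceQueueOp. A hedged sketch of driving it directly, assuming it runs inside a Status-returning function and `ds` is an IR tree topped with a TransferNode:

    auto consumer = std::make_unique<ToDevice>("Ascend", /*send_epoch_end=*/true);
    RETURN_IF_NOT_OK(consumer->Init(ds));    // builds and prepares the execution tree
    RETURN_IF_NOT_OK(consumer->Send());      // launches the tree and pulls the first buffer
    RETURN_IF_NOT_OK(consumer->Stop());      // forwards to DeviceQueueOp::StopSend()
    RETURN_IF_NOT_OK(consumer->Continue());  // forwards to DeviceQueueOp::ContinueSend()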
#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {

View File

@ -126,23 +126,27 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and sends it to a device
class ToDevice : public TreeConsumer {
public:
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs)
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
Status Init(std::shared_ptr<api::Dataset> d) override;
Status Send() {
// TODO(CRC): launch the tree
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status Stop() {
// TODO(CRC): Get root + call StopSend
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status Continue() {
// TODO(CRC): Get root + call StopSend
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// Send the data to the device
/// \return Status error code
Status Send();
/// Stop sending data to the device
/// \return Status error code
Status Stop();
/// Continue sending data to the device
/// \return Status error code
Status Continue();
protected:
/// Method to return the name of the consumer
/// \return string
std::string Name() override { return "ToDevice"; }
private:
std::string device_type_;

View File

@ -15,6 +15,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
skip_node.cc
sync_wait_node.cc
take_node.cc
transfer_node.cc
zip_node.cc
)

View File

@ -68,6 +68,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
return node_ops;
}
// Get the shard id of node
Status AlbumNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore
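The same override is added to every leaf IR node in the hunks that follow; the only variation is where the shard id comes from. A condensed sketch of the two flavors (the class names here are placeholders):

    // Sampler-based sources (Album, CelebA, Cifar, Coco, ImageFolder, Manifest, MindData, Mnist, Random, VOC)
    Status SamplerBasedNode::GetShardId(int32_t *shard_id) {
      *shard_id = sampler_->ShardId();  // DistributedSamplerObj reports its shard_id_, other samplers report 0
      return Status::OK();
    }

    // Sharded file readers (CLUE, CSV, TextFile, TFRecord) already hold the shard id as a member
    Status FileReaderNode::GetShardId(int32_t *shard_id) {
      *shard_id = shard_id_;
      return Status::OK();
    }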

View File

@ -44,6 +44,10 @@ class AlbumNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string schema_path_;

View File

@ -67,6 +67,13 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
return node_ops;
}
// Get the shard id of node
Status CelebANode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -46,6 +46,10 @@ class CelebANode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -66,6 +66,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
return node_ops;
}
// Get the shard id of node
Status Cifar100Node::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -44,6 +44,10 @@ class Cifar100Node : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -64,6 +64,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
return node_ops;
}
// Get the shard id of node
Status Cifar10Node::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -44,6 +44,10 @@ class Cifar10Node : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -213,6 +213,13 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
return node_ops;
}
// Get the shard id of node
Status CLUENode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -45,6 +45,10 @@ class CLUENode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
/// \brief Split string based on a character delimiter
/// \return A string vector

View File

@ -117,6 +117,14 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
return node_ops;
}
// Get the shard id of node
Status CocoNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -43,6 +43,10 @@ class CocoNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string annotation_file_;

View File

@ -122,6 +122,14 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
return node_ops;
}
// Get the shard id of node
Status CSVNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -66,6 +66,10 @@ class CSVNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::vector<std::string> dataset_files_;
char field_delim_;

View File

@ -70,6 +70,14 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
std::move(sampler_->Build())));
return node_ops;
}
// Get the shard id of node
Status ImageFolderNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -51,6 +51,10 @@ class ImageFolderNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
bool decode_;

View File

@ -85,6 +85,14 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
return node_ops;
}
// Get the shard id of node
Status ManifestNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -44,6 +44,10 @@ class ManifestNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_file_;
std::string usage_;

View File

@ -160,6 +160,13 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
return node_ops;
}
// Get the shard id of node
Status MindDataNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -48,6 +48,10 @@ class MindDataNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Build sampler chain for minddata dataset
/// \return Status Status::OK() if input sampler is valid
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,

View File

@ -60,6 +60,13 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
return node_ops;
}
// Get the shard id of node
Status MnistNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -44,6 +44,10 @@ class MnistNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

View File

@ -99,6 +99,13 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() {
return node_ops;
}
// Get the shard id of node
Status RandomNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -65,6 +65,10 @@ class RandomNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
/// \brief A quick inline for producing a random number between (and including) min/max
/// \param[in] min minimum number that can be generated.

View File

@ -95,6 +95,13 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() {
return node_ops;
}
// Get the shard id of node
Status TextFileNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -45,6 +45,10 @@ class TextFileNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::vector<std::string> dataset_files_;
int32_t num_samples_;

View File

@ -80,6 +80,13 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() {
return node_ops;
}
// Get the shard id of node
Status TFRecordNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -72,6 +72,10 @@ class TFRecordNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::vector<std::string> dataset_files_;
std::string schema_path_; // schema_path_ path to schema file. It is set when type of schema parameter is string

View File

@ -112,6 +112,13 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() {
return node_ops;
}
// Get the shard id of node
Status VOCNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

View File

@ -45,6 +45,10 @@ class VOCNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if the shard id is obtained successfully
Status GetShardId(int32_t *shard_id) override;
private:
const std::string kColumnImage = "image";
const std::string kColumnTarget = "target";

View File

@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
const std::string &device_type, bool send_epoch_end)
: queue_name_(queue_name),
device_id_(device_id),
device_type_(device_type),
prefetch_size_(16),
send_epoch_end_(send_epoch_end),
total_batch_(0) {
this->children.push_back(child);
}
// Validator for TransferNode
Status TransferNode::ValidateParams() {
// Check if device_type_ is in {"CPU", "GPU", "Ascend"}
RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"}));
return Status::OK();
}
// Function to build TransferNode
std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
// Convert device_type_ from string to DeviceType
DeviceQueueOp::DeviceType type;
if (device_type_ == "CPU") {
type = DeviceQueueOp::DeviceType::CPU;
} else if (device_type_ == "GPU") {
type = DeviceQueueOp::DeviceType::GPU;
} else if (device_type_ == "Ascend") {
type = DeviceQueueOp::DeviceType::Ascend;
}
node_ops.push_back(
std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_));
return node_ops;
}
// Function to get the device_id (derived from the shard id of the dataset)
Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) {
// Get device id according to the type of dataset
Status rc = ds->GetShardId(device_id);
if (rc != Status::OK()) {
// Get device id from the child node
if (ds->children.size()) {
ds = ds->children[0];
return TransferNode::get_distribution(ds, device_id);
} else {
std::string err_msg = "Unknown dataset type.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore
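A hedged sketch of how get_distribution() resolves the device id for a typical pipeline (the ImageFolder and DistributedSampler factories and their arguments are illustrative):

    // ImageFolderNode overrides GetShardId(); RepeatNode does not, so the walk recurses past it.
    std::shared_ptr<Dataset> ds =
        ImageFolder("/data/train", /*decode=*/true, DistributedSampler(/*num_shards=*/8, /*shard_id=*/3));
    ds = ds->Repeat(2);
    int32_t device_id = 0;
    Status rc = TransferNode::get_distribution(ds, &device_id);
    // On success, device_id == 3, i.e. the shard id reported by the DistributedSampler.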

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
class TransferNode : public Dataset {
public:
/// \brief Constructor
TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
const std::string &device_type, bool send_epoch_end);
/// \brief Destructor
~TransferNode() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id);
private:
std::string queue_name_;
int32_t device_id_;
std::string device_type_;
int32_t prefetch_size_;
bool send_epoch_end_;
int32_t total_batch_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_TRANSFER_NODE_H_
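This is the node that Dataset::DeviceQueue() places on top of the pipeline; a hedged sketch of constructing one by hand (the queue name, device id, and device type are illustrative):

    auto xfer = std::make_shared<TransferNode>(ds, Services::GetUniqueID(), /*device_id=*/0, "GPU",
                                               /*send_epoch_end=*/true);
    RETURN_IF_NOT_OK(xfer->ValidateParams());  // rejects device types outside {"CPU", "GPU", "Ascend"}
    // A ToDevice consumer then takes over: consumer->Init(xfer) builds the tree, consumer->Send() launches it.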

View File

@ -57,6 +57,10 @@ class TreeAdapter {
// to be able to launch a thread. BuildAndPrepare needs to be called before this function
TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
std::shared_ptr<DatasetOp> root() { return tree_->root(); }
Status Launch() const { return tree_->Launch(); }
private:
// This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In
// such case, the first node is returned. Op is added as child when the current function returns.

View File

@ -96,6 +96,7 @@ class RepeatNode;
class ShuffleNode;
class SkipNode;
class TakeNode;
class TransferNode;
class ZipNode;
#define RETURN_EMPTY_IF_ERROR(_s) \
@ -559,6 +560,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
public:
// need friend class so they can access the children_ field
friend class Iterator;
friend class TransferNode;
friend class mindspore::dataset::TreeAdapter;
/// \brief Constructor
@ -579,6 +581,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0;
/// \brief Virtual function for derived classes to get the shard id of a specific node
/// \return Status Status::OK() if the shard id is obtained successfully
virtual Status GetShardId(int32_t *shard_id) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// \brief Gets the dataset size
/// \return status code
int64_t GetDatasetSize();
@ -617,6 +625,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the Iterator
std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
/// \brief Function to transfer data through a device.
/// \notes If the device is Ascend, the features of the data are transferred one by one; the size of
///     each transfer is limited to 256 MB.
/// \param[in] send_epoch_end Whether to send an end-of-sequence marker to the device (default=true).
/// \return Returns true if no error is encountered, false otherwise.
bool DeviceQueue(bool send_epoch_end = true);
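For orientation, DeviceQueue() sits beside CreateIterator() as a second way to consume the tree; a hedged sketch of the two call sites:

    auto iter = ds->CreateIterator();  // host-side consumption through an Iterator
    bool ok = ds->DeviceQueue();       // device-side consumption through a DeviceQueueOp; no iterator returned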
#ifndef ENABLE_ANDROID
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
/// \note Usage restrictions:

View File

@ -17,8 +17,9 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
#include <vector>
#include <memory>
#include <string>
#include <vector>
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
@ -48,6 +49,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<Sampler> Build() = 0;
/// \brief Function for derived class to get the shard id of sampler
/// \return The shard id of the derived sampler
virtual int64_t ShardId() { return 0; }
#ifndef ENABLE_ANDROID
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@ -134,6 +139,10 @@ class DistributedSamplerObj : public SamplerObj {
bool ValidateParams() override;
/// \brief Function to get the shard id of sampler
/// \return The shard id of sampler
int64_t ShardId() override { return shard_id_; }
private:
int64_t num_shards_;
int64_t shard_id_;
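A hedged sketch of what the new virtual resolves to (factory names as in the existing samplers API; the shard values are illustrative):

    std::shared_ptr<SamplerObj> plain = RandomSampler();                     // ShardId() -> 0 (base default)
    std::shared_ptr<SamplerObj> dist = DistributedSampler(/*num_shards=*/8,
                                                          /*shard_id=*/5);   // ShardId() -> 5
    // Leaf IR nodes forward this value from GetShardId(): *shard_id = sampler_->ShardId();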