forked from mindspore-Ecosystem/mindspore
MindRecord becomes a mappable dataset
updates review touches ci fixes ci fix ci fixin ci fix review updates further cleanup updates updates update update comment
This commit is contained in:
parent
dbb79b3e49
commit
515c936b85
|
@ -25,7 +25,7 @@
|
|||
#include "minddata/dataset/include/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/util/log_adapter.h"
|
||||
|
@ -67,9 +67,13 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
|
|||
if (build_num_padded_ > 0) {
|
||||
sample_json = ToJson(build_sample_);
|
||||
}
|
||||
new_mind_record_op = std::make_shared<MindRecordOp>(
|
||||
build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, build_op_connector_queue_size_,
|
||||
build_columns_to_load_, build_operators_, build_num_padded_, sample_json, build_sample_bytes_);
|
||||
|
||||
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
||||
|
||||
new_mind_record_op =
|
||||
std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_,
|
||||
build_op_connector_queue_size_, build_columns_to_load_, build_operators_,
|
||||
build_num_padded_, sample_json, build_sample_bytes_, std::move(shard_reader));
|
||||
|
||||
RETURN_IF_NOT_OK(new_mind_record_op->Init());
|
||||
*ptr = std::move(new_mind_record_op);
|
||||
|
@ -110,8 +114,10 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) {
|
|||
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
|
||||
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded,
|
||||
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes)
|
||||
: MappableLeafOp(num_mind_record_workers, op_connector_queue_size, std::make_shared<SequentialSamplerRT>(0, 0)),
|
||||
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes,
|
||||
std::unique_ptr<ShardReader> shard_reader)
|
||||
: MappableLeafOp(num_mind_record_workers, op_connector_queue_size,
|
||||
std::make_shared<MindRecordSamplerRT>(shard_reader.get())),
|
||||
dataset_file_(dataset_file),
|
||||
load_dataset_(load_dataset),
|
||||
columns_to_load_(columns_to_load),
|
||||
|
@ -120,7 +126,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
|
|||
ended_worker_(0),
|
||||
num_padded_(num_padded),
|
||||
sample_json_(sample_json),
|
||||
sample_bytes_(sample_bytes) {
|
||||
sample_bytes_(sample_bytes),
|
||||
shard_reader_(std::move(shard_reader)) {
|
||||
io_block_queues_.Init(num_workers_, op_connector_queue_size);
|
||||
epoch_sync_flag_ = true; // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all
|
||||
// tasks are consumed by the worker threads would cause problem.
|
||||
|
@ -128,7 +135,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
|
|||
|
||||
// Private helper method to encapsulate some common construction/reset tasks
|
||||
Status MindRecordOp::Init() {
|
||||
shard_reader_ = std::make_unique<ShardReader>();
|
||||
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
|
||||
num_padded_);
|
||||
|
||||
|
@ -363,9 +369,6 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
|
|||
Status MindRecordOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
RETURN_IF_NOT_OK(MappableLeafOp::Reset()); // Call our super class reset first.
|
||||
|
||||
shard_reader_->ShuffleTask();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -140,7 +140,8 @@ class MindRecordOp : public MappableLeafOp {
|
|||
MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
|
||||
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
|
||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded_,
|
||||
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_);
|
||||
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_,
|
||||
std::unique_ptr<ShardReader> shard_reader);
|
||||
|
||||
// Destructor
|
||||
~MindRecordOp() override;
|
||||
|
|
|
@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
|
|||
subset_random_sampler.cc
|
||||
subset_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
mind_record_sampler.cc
|
||||
)
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
||||
#include "minddata/mindrecord/include/shard_reader.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
MindRecordSamplerRT::MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, int64_t samples_per_tensor)
|
||||
: SamplerRT(0, samples_per_tensor), next_id_(0), shard_reader_(shard_reader) {}
|
||||
|
||||
Status MindRecordSamplerRT::GetNextSample(TensorRow *out) {
|
||||
if (next_id_ > num_samples_) {
|
||||
RETURN_STATUS_UNEXPECTED("MindRecordSampler Internal Error");
|
||||
} else if (next_id_ == num_samples_) {
|
||||
(*out) = TensorRow(TensorRow::kFlagEOE);
|
||||
} else {
|
||||
std::shared_ptr<Tensor> sampleIdsTensor;
|
||||
int64_t last_id = std::min(samples_per_tensor_ + next_id_, num_samples_);
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIdsTensor, last_id - next_id_));
|
||||
auto id_ptr = sampleIdsTensor->begin<int64_t>();
|
||||
for (int64_t i = 0; i < (last_id - next_id_); i++) {
|
||||
*(id_ptr + static_cast<ptrdiff_t>(i)) = (*sample_ids_)[i];
|
||||
}
|
||||
next_id_ = last_id;
|
||||
|
||||
(*out) = {sampleIdsTensor};
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MindRecordSamplerRT::InitSampler() {
|
||||
sample_ids_ = shard_reader_->GetSampleIds();
|
||||
|
||||
if (!sample_ids_) {
|
||||
// Note, sample_ids_.empty() is okay and will just give no sample ids.
|
||||
RETURN_STATUS_UNEXPECTED("ShardReader did not provide a valid sample ids vector via MindRecordSamplerRT");
|
||||
}
|
||||
|
||||
// Usually, the num samples is given from the user interface. In our case, that data is in mindrecord.
|
||||
// Mindrecord already created the sample ids at this point, so the num samples is the size of the sampled id list.
|
||||
num_samples_ = sample_ids_->size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MindRecordSamplerRT::ResetSampler() {
|
||||
// drive the shard reader reshuffle tasks to redo the sampling for another epoch
|
||||
next_id_ = 0;
|
||||
shard_reader_->ShuffleTask();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MindRecordSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
|
||||
out << "\nSampler: MindRecordSampler";
|
||||
if (show_all) {
|
||||
// Call the super class for displaying any common detailed info
|
||||
SamplerRT::SamplerPrint(out, show_all);
|
||||
// Then add our own info if any
|
||||
}
|
||||
}
|
||||
|
||||
Status MindRecordSamplerRT::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sampler_name"] = "MindRecordSampler";
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/mindrecord/include/shard_reader.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class MindRecordSamplerRT : public SamplerRT {
|
||||
public:
|
||||
// Constructor
|
||||
// @param shard_reader - shard_reader
|
||||
// @param int64_t samples_per_tensor - Num of Sampler Ids to fetch via 1 GetNextSample call
|
||||
MindRecordSamplerRT(mindrecord::ShardReader *shard_reader,
|
||||
int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
~MindRecordSamplerRT() = default;
|
||||
|
||||
// Op calls this to get next set of sampleIds
|
||||
// @param out - Tensor of sample ids to be returned to caller
|
||||
// @return Status The status code returned
|
||||
Status GetNextSample(TensorRow *out) override;
|
||||
|
||||
// meant to be called by base class or python
|
||||
Status InitSampler() override;
|
||||
|
||||
// for next epoch of sampleIds
|
||||
// @return Status The status code returned
|
||||
Status ResetSampler() override;
|
||||
|
||||
void SamplerPrint(std::ostream &out, bool show_all) const override;
|
||||
|
||||
/// \brief Get the arguments of node
|
||||
/// \param[out] out_json JSON string of all attributes
|
||||
/// \return Status of the function
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
mindrecord::ShardReader *shard_reader_; // back pointer to the shard reader
|
||||
const std::vector<int> *sample_ids_; // read-only back pointer into mind record sampler ids
|
||||
int64_t next_id_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
|
|
@ -23,6 +23,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
|
@ -155,17 +156,19 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
|
||||
|
||||
std::shared_ptr<MindRecordOp> mindrecord_op;
|
||||
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
||||
|
||||
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
|
||||
// else if pass a vector to MindData(), it will be treated as specified files to be read
|
||||
if (search_for_pattern_) {
|
||||
std::vector<std::string> dataset_file_vec_ = {dataset_file_};
|
||||
mindrecord_op =
|
||||
std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_,
|
||||
columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_);
|
||||
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_,
|
||||
connector_que_size_, columns_list_, operators_, num_padded_,
|
||||
padded_sample_, sample_bytes_, std::move(shard_reader));
|
||||
} else {
|
||||
mindrecord_op =
|
||||
std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_, connector_que_size_,
|
||||
columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_);
|
||||
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_,
|
||||
connector_que_size_, columns_list_, operators_, num_padded_,
|
||||
padded_sample_, sample_bytes_, std::move(shard_reader));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(mindrecord_op->Init());
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato
|
|||
|
||||
bool GetReplacement() const { return replacement_; }
|
||||
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTaskList &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -39,15 +39,15 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha
|
|||
|
||||
~ShardDistributedSample() override{};
|
||||
|
||||
MSRStatus PreExecute(ShardTask &tasks) override;
|
||||
MSRStatus PreExecute(ShardTaskList &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
private:
|
||||
bool shuffle_;
|
||||
int no_of_padded_samples_;
|
||||
bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch
|
||||
ShardTask task_; // maintain the input tasks in first epoch
|
||||
bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch
|
||||
ShardTaskList task_; // maintain the input tasks in first epoch
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,7 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/mindrecord/include/shard_task.h"
|
||||
#include "minddata/mindrecord/include/shard_task_list.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
|
@ -26,7 +26,7 @@ class __attribute__((visibility("default"))) ShardOperator {
|
|||
public:
|
||||
virtual ~ShardOperator() = default;
|
||||
|
||||
MSRStatus operator()(ShardTask &tasks) {
|
||||
MSRStatus operator()(ShardTaskList &tasks) {
|
||||
if (SUCCESS != this->PreExecute(tasks)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -47,11 +47,11 @@ class __attribute__((visibility("default"))) ShardOperator {
|
|||
|
||||
virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
|
||||
|
||||
virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
|
||||
virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; }
|
||||
|
||||
virtual MSRStatus Execute(ShardTask &tasks) = 0;
|
||||
virtual MSRStatus Execute(ShardTaskList &tasks) = 0;
|
||||
|
||||
virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
|
||||
virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; }
|
||||
|
||||
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor
|
|||
|
||||
~ShardPkSample() override{};
|
||||
|
||||
MSRStatus SufExecute(ShardTask &tasks) override;
|
||||
MSRStatus SufExecute(ShardTaskList &tasks) override;
|
||||
|
||||
int64_t GetNumSamples() const { return num_samples_; }
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -210,6 +210,9 @@ class API_PUBLIC ShardReader {
|
|||
/// \brief get the size of blob data
|
||||
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
|
||||
|
||||
/// \brief get a read-only ptr to the sampled ids for this epoch
|
||||
const std::vector<int> *GetSampleIds();
|
||||
|
||||
protected:
|
||||
/// \brief sqlite call back function
|
||||
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
|
||||
|
@ -322,7 +325,7 @@ class API_PUBLIC ShardReader {
|
|||
std::vector<std::string> selected_columns_; // columns which will be read
|
||||
std::map<string, uint64_t> column_schema_id_; // column-schema map
|
||||
std::vector<std::shared_ptr<ShardOperator>> operators_; // data operators, including shuffle, sample and category
|
||||
ShardTask tasks_; // shard task
|
||||
ShardTaskList tasks_; // shard task list
|
||||
std::mutex shard_locker_; // locker of shard
|
||||
|
||||
// flags
|
||||
|
@ -339,7 +342,7 @@ class API_PUBLIC ShardReader {
|
|||
std::mutex mtx_delivery_; // locker for delivery
|
||||
std::condition_variable cv_delivery_; // conditional variable for delivery
|
||||
std::condition_variable cv_iterator_; // conditional variable for iterator
|
||||
std::atomic<int> task_id_; // task ID which is working
|
||||
std::atomic<int> sample_id_position_; // index into the sample ids vector for the current sample id
|
||||
std::atomic<int> deliver_id_; // delivery ID which is picked up by iterator
|
||||
// map of delivery
|
||||
std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator
|
|||
|
||||
~ShardSample() override{};
|
||||
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTaskList &tasks) override;
|
||||
|
||||
MSRStatus UpdateTasks(ShardTask &tasks, int taking);
|
||||
MSRStatus UpdateTasks(ShardTaskList &tasks, int taking);
|
||||
|
||||
MSRStatus SufExecute(ShardTask &tasks) override;
|
||||
MSRStatus SufExecute(ShardTaskList &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar
|
|||
|
||||
~ShardSequentialSample() override{};
|
||||
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTaskList &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,11 +31,14 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator
|
|||
|
||||
~ShardShuffle() override{};
|
||||
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTaskList &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
private:
|
||||
// Private helper function
|
||||
MSRStatus CategoryShuffle(ShardTaskList &tasks);
|
||||
|
||||
uint32_t shuffle_seed_;
|
||||
int64_t no_of_samples_;
|
||||
bool replacement_;
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
class __attribute__((visibility("default"))) ShardTask {
|
||||
public:
|
||||
ShardTask();
|
||||
|
||||
ShardTask(const ShardTask &task); // copy construction
|
||||
|
||||
ShardTask &operator=(const ShardTask &task); // assignment operator
|
||||
|
||||
~ShardTask() = default;
|
||||
|
||||
void MakePerm();
|
||||
|
||||
inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
||||
const json &label);
|
||||
|
||||
inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label);
|
||||
|
||||
inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
||||
|
||||
inline void InsertTask(const uint32_t &i,
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
||||
|
||||
void PopBack();
|
||||
|
||||
uint32_t Size() const;
|
||||
|
||||
uint32_t SizeOfRows() const;
|
||||
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id);
|
||||
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
|
||||
|
||||
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples);
|
||||
|
||||
inline void ResizeTask(const uint32_t &size);
|
||||
|
||||
uint32_t categories;
|
||||
|
||||
// The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n]
|
||||
std::vector<int> permutation_;
|
||||
|
||||
// The data struct is as below:
|
||||
// 1. TaskType: kCommonTask / kPaddedTask
|
||||
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
|
||||
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
|
||||
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
|
||||
};
|
||||
|
||||
inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
||||
const json &label) {
|
||||
MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
|
||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
||||
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
|
||||
}
|
||||
|
||||
inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label) {
|
||||
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
|
||||
}
|
||||
|
||||
inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
||||
MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
|
||||
<< ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
|
||||
<< ", size of task_list_: " << task_list_.size() << ".";
|
||||
|
||||
task_list_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
inline void ShardTask::InsertTask(const uint32_t &i,
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
||||
task_list_[i] = std::move(task);
|
||||
}
|
||||
|
||||
inline void ShardTask::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
|
||||
// The data struct is as below:
|
||||
// 1. TaskType: kCommonTask / kPaddedTask
|
||||
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
|
||||
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
|
||||
using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
|
||||
|
||||
class __attribute__((visibility("default"))) ShardTaskList {
|
||||
public:
|
||||
ShardTaskList();
|
||||
|
||||
ShardTaskList(const ShardTaskList &task); // copy construction
|
||||
|
||||
ShardTaskList &operator=(const ShardTaskList &task); // assignment operator
|
||||
|
||||
~ShardTaskList() = default;
|
||||
|
||||
void InitSampleIds();
|
||||
|
||||
static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);
|
||||
|
||||
// Assigns the task based on task id
|
||||
inline void AssignTask(ShardTaskList &sourceTasks, size_t id);
|
||||
|
||||
inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
||||
const json &label);
|
||||
|
||||
inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label);
|
||||
|
||||
inline void InsertTask(ShardTask task);
|
||||
|
||||
inline void InsertTask(const uint32_t &i, ShardTask task);
|
||||
|
||||
void MakePerm();
|
||||
|
||||
inline void InsertSampleId(int id);
|
||||
|
||||
void PopBack();
|
||||
|
||||
uint32_t Size() const;
|
||||
|
||||
uint32_t SizeOfRows() const;
|
||||
|
||||
ShardTask &GetTaskByID(size_t id);
|
||||
|
||||
ShardTask &GetRandomTask();
|
||||
|
||||
int GetTaskSampleByID(size_t id);
|
||||
|
||||
int GetRandomTaskID();
|
||||
|
||||
static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples);
|
||||
|
||||
inline void ResizeTask(const uint32_t &size);
|
||||
|
||||
uint32_t categories;
|
||||
|
||||
std::vector<int> permutation_; // A list of ints used for shuffling sample ids
|
||||
|
||||
std::vector<int> sample_ids_; // The list of actual ids that were sampled
|
||||
|
||||
std::vector<ShardTask> task_list_; // The full list of tasks
|
||||
};
|
||||
|
||||
inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) {
|
||||
// Insert the sample id from the source into ourself by indexing at id position.
|
||||
// Important: The task list itself does not change.
|
||||
int sample_id = sourceTasks.GetTaskSampleByID(id);
|
||||
MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
|
||||
sample_ids_.push_back(sample_id);
|
||||
}
|
||||
|
||||
inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label) {
|
||||
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
|
||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
||||
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
|
||||
}
|
||||
|
||||
inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label) {
|
||||
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
|
||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
||||
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
|
||||
}
|
||||
|
||||
inline void ShardTaskList::InsertTask(ShardTask task) {
|
||||
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
|
||||
<< ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
|
||||
<< ", size of task_list_: " << task_list_.size() << ".";
|
||||
|
||||
task_list_.push_back(std::move(task));
|
||||
}
|
||||
|
||||
inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); }
|
||||
|
||||
inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,7 +46,7 @@ ShardReader::ShardReader()
|
|||
num_padded_(0),
|
||||
num_rows_(0),
|
||||
total_blob_size_(0),
|
||||
task_id_(0),
|
||||
sample_id_position_(0),
|
||||
deliver_id_(0),
|
||||
lazy_load_(false),
|
||||
shard_sample_count_() {}
|
||||
|
@ -1088,9 +1088,8 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator
|
|||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Generate task list, a task will create a batch
|
||||
std::vector<ShardTask> categoryTasks(categories.size());
|
||||
// Generate a vector of task lists. Each catogory has a list of tasks.
|
||||
std::vector<ShardTaskList> categoryTasks(categories.size());
|
||||
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
|
||||
int category_index = 0;
|
||||
for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
|
||||
|
@ -1122,7 +1121,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator
|
|||
}
|
||||
}
|
||||
}
|
||||
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
|
||||
tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
|
||||
|
||||
tasks_.InitSampleIds();
|
||||
if (SUCCESS != (*category_op)(tasks_)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -1246,6 +1247,10 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
|
|||
}
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Created initial list of tasks. There are " << tasks_.Size() << " to start with before sampling.";
|
||||
|
||||
tasks_.InitSampleIds();
|
||||
|
||||
for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
|
||||
const auto &op = operators[operator_no];
|
||||
if (std::dynamic_pointer_cast<ShardCategory>(op)) continue;
|
||||
|
@ -1256,7 +1261,9 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
|
|||
|
||||
if (tasks_.permutation_.empty()) tasks_.MakePerm();
|
||||
num_rows_ = tasks_.Size();
|
||||
MS_LOG(INFO) << "Total rows is " << num_rows_;
|
||||
MS_LOG(INFO) << "Total rows is " << num_rows_
|
||||
<< " and total amount sampled initially is: " << tasks_.sample_ids_.size();
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -1272,9 +1279,9 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
uint32_t blob_start = 0;
|
||||
uint32_t blob_end = 0;
|
||||
json var_fields;
|
||||
|
||||
// Pick up task from task list
|
||||
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
|
||||
ShardTask task;
|
||||
task = tasks_.GetTaskByID(task_id);
|
||||
|
||||
// check task type
|
||||
auto task_type = std::get<0>(task);
|
||||
|
@ -1354,16 +1361,16 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
|
|||
|
||||
// Loop forever
|
||||
for (;;) {
|
||||
int task_id = 0;
|
||||
int sample_id_pos = 0;
|
||||
|
||||
// Get next task ID
|
||||
task_id = task_id_++;
|
||||
sample_id_pos = sample_id_position_++;
|
||||
|
||||
// All tasks are done
|
||||
if (task_id >= static_cast<int>(tasks_.Size())) {
|
||||
if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
|
||||
return FAILED;
|
||||
}
|
||||
const auto &ret = ConsumerOneTask(task_id, consumer_id);
|
||||
const auto &ret = ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id);
|
||||
if (SUCCESS != ret.first) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -1372,11 +1379,13 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
|
|||
// otherwise, set batch data in map
|
||||
{
|
||||
std::unique_lock<std::mutex> lck(mtx_delivery_);
|
||||
cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; });
|
||||
cv_delivery_.wait(lck,
|
||||
[sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
|
||||
if (interrupt_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
delivery_map_[task_id] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
|
||||
delivery_map_[sample_id_pos] =
|
||||
std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
|
||||
}
|
||||
cv_iterator_.notify_one();
|
||||
}
|
||||
|
@ -1386,7 +1395,7 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
|
|||
if (interrupt_) {
|
||||
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
|
||||
}
|
||||
if (deliver_id_ >= static_cast<int>(tasks_.Size())) {
|
||||
if (deliver_id_ >= static_cast<int>(tasks_.sample_ids_.size())) {
|
||||
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
|
||||
}
|
||||
|
||||
|
@ -1458,7 +1467,7 @@ std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> Sha
|
|||
void ShardReader::Reset() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mtx_delivery_);
|
||||
task_id_ = 0;
|
||||
sample_id_position_ = 0;
|
||||
deliver_id_ = 0;
|
||||
}
|
||||
cv_delivery_.notify_all();
|
||||
|
@ -1486,5 +1495,10 @@ void ShardReader::ShuffleTask() {
|
|||
if (tasks_.permutation_.empty()) tasks_.MakePerm();
|
||||
}
|
||||
|
||||
const std::vector<int> *ShardReader::GetSampleIds() {
|
||||
// return const reference to private sample id list.
|
||||
return &(this->tasks_.sample_ids_);
|
||||
}
|
||||
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
|
|||
num_categories_(num_categories),
|
||||
replacement_(replacement) {}
|
||||
|
||||
MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
|
||||
MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
|
||||
|
||||
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (dataset_size == 0) return dataset_size;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -55,7 +55,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
|
|||
return 0;
|
||||
}
|
||||
|
||||
MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
|
||||
MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
|
||||
auto total_no = tasks.Size();
|
||||
if (no_of_padded_samples_ > 0 && first_epoch_) {
|
||||
if (total_no % denominator_ != 0) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -37,7 +37,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
|
|||
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
|
||||
}
|
||||
|
||||
MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
|
||||
MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) {
|
||||
if (shuffle_ == true) {
|
||||
if (SUCCESS != (*shuffle_op_)(tasks)) {
|
||||
return FAILED;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -80,21 +80,21 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) {
|
||||
MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
|
||||
if (tasks.permutation_.empty()) {
|
||||
ShardTask new_tasks;
|
||||
int total_no = static_cast<int>(tasks.Size());
|
||||
ShardTaskList new_tasks;
|
||||
int total_no = static_cast<int>(tasks.sample_ids_.size());
|
||||
if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
|
||||
for (int i = 0; i < indices_.size(); ++i) {
|
||||
int index = ((indices_[i] % total_no) + total_no) % total_no;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
|
||||
new_tasks.AssignTask(tasks, index); // different mod result between c and python
|
||||
}
|
||||
} else {
|
||||
int count = 0;
|
||||
if (nums_per_shard_.empty()) {
|
||||
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
|
||||
new_tasks.AssignTask(tasks, i % total_no); // rounding up. if overflow, go back to start
|
||||
count++;
|
||||
}
|
||||
} else {
|
||||
|
@ -102,33 +102,33 @@ MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) {
|
|||
size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
|
||||
for (; i < nums_per_shard_[partition_id_]; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
|
||||
new_tasks.AssignTask(tasks, i % total_no);
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
} else {
|
||||
ShardTask new_tasks;
|
||||
if (taking > static_cast<int>(tasks.permutation_.size())) {
|
||||
ShardTaskList new_tasks;
|
||||
if (taking > static_cast<int>(tasks.sample_ids_.size())) {
|
||||
return FAILED;
|
||||
}
|
||||
int total_no = static_cast<int>(tasks.permutation_.size());
|
||||
int count = 0;
|
||||
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
if (no_of_samples_ != 0 && count == no_of_samples_) break;
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
|
||||
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
|
||||
count++;
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
||||
MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
|
||||
if (offset_ != -1) {
|
||||
int64_t old_v = 0;
|
||||
int num_rows_ = static_cast<int>(tasks.Size());
|
||||
int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
|
||||
for (int x = 0; x < denominator_; x++) {
|
||||
int samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
|
||||
int remainder = (num_rows_ + offset_) % denominator_;
|
||||
|
@ -140,8 +140,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
|||
}
|
||||
}
|
||||
int no_of_categories = static_cast<int>(tasks.categories);
|
||||
int total_no = static_cast<int>(tasks.Size()); // make sure task_size
|
||||
|
||||
int total_no = static_cast<int>(tasks.sample_ids_.size());
|
||||
int taking = 0;
|
||||
if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1
|
||||
no_of_samples_ = std::min(no_of_samples_, total_no);
|
||||
|
@ -167,7 +166,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
|||
return UpdateTasks(tasks, taking);
|
||||
}
|
||||
|
||||
MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
|
||||
MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) {
|
||||
if (sampler_type_ == kSubsetRandomSampler) {
|
||||
if (SUCCESS != (*shuffle_op_)(tasks)) {
|
||||
return FAILED;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,9 +38,9 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c
|
|||
return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
|
||||
}
|
||||
|
||||
MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
|
||||
int64_t total_no = static_cast<int64_t>(tasks.Size());
|
||||
MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) {
|
||||
int64_t taking;
|
||||
int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size());
|
||||
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
|
||||
taking = total_no;
|
||||
} else if (per_ > kEpsilon && per_ <= 1.0f) {
|
||||
|
@ -50,22 +50,22 @@ MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
|
|||
}
|
||||
|
||||
if (tasks.permutation_.empty()) {
|
||||
ShardTask new_tasks;
|
||||
ShardTaskList new_tasks;
|
||||
total_no = static_cast<int64_t>(tasks.Size());
|
||||
for (size_t i = offset_; i < taking + offset_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
|
||||
new_tasks.AssignTask(tasks, i % total_no);
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
} else { // shuffled
|
||||
ShardTask new_tasks;
|
||||
ShardTaskList new_tasks;
|
||||
if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
|
||||
return FAILED;
|
||||
}
|
||||
total_no = static_cast<int64_t>(tasks.permutation_.size());
|
||||
for (size_t i = offset_; i < taking + offset_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
|
||||
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,7 +42,31 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
|||
return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
|
||||
}
|
||||
|
||||
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
||||
MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
|
||||
uint32_t individual_size;
|
||||
individual_size = tasks.sample_ids_.size() / tasks.categories;
|
||||
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
|
||||
for (uint32_t i = 0; i < tasks.categories; i++) {
|
||||
for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
|
||||
std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
|
||||
}
|
||||
tasks.permutation_.clear(); // Jamie replace this we setting flag to false or something
|
||||
for (uint32_t j = 0; j < individual_size; j++) {
|
||||
for (uint32_t i = 0; i < tasks.categories; i++) {
|
||||
tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
|
||||
}
|
||||
}
|
||||
|
||||
ShardTaskList new_tasks;
|
||||
for (size_t i = 0; i < individual_size; ++i) {
|
||||
new_tasks.AssignTask(tasks, tasks.permutation_[i]);
|
||||
}
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
|
||||
if (reshuffle_each_epoch_) shuffle_seed_++;
|
||||
if (tasks.categories < 1) {
|
||||
return FAILED;
|
||||
|
@ -52,43 +76,31 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
|||
tasks.MakePerm();
|
||||
}
|
||||
if (replacement_ == true) {
|
||||
ShardTask new_tasks;
|
||||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.Size());
|
||||
}
|
||||
ShardTaskList new_tasks;
|
||||
if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
|
||||
if (no_of_samples_ <= 0) {
|
||||
MS_LOG(ERROR) << "no_of_samples need to be positive.";
|
||||
return FAILED;
|
||||
}
|
||||
new_tasks.task_list_.reserve(no_of_samples_);
|
||||
for (uint32_t i = 0; i < no_of_samples_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetRandomTask());
|
||||
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
} else {
|
||||
std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
|
||||
auto total_no = static_cast<int64_t>(tasks.Size());
|
||||
if (no_of_samples_ > 0 && no_of_samples_ < total_no) {
|
||||
ShardTask new_tasks;
|
||||
for (size_t i = 0; i < no_of_samples_; ++i) {
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i));
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
ShardTaskList new_tasks;
|
||||
size_t samples_to_assign =
|
||||
(no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
|
||||
for (size_t i = 0; i < samples_to_assign; ++i) {
|
||||
new_tasks.AssignTask(tasks, tasks.permutation_[i]);
|
||||
}
|
||||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
}
|
||||
} else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
|
||||
uint32_t individual_size = tasks.Size() / tasks.categories;
|
||||
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
|
||||
for (uint32_t i = 0; i < tasks.categories; i++) {
|
||||
for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
|
||||
std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
|
||||
}
|
||||
tasks.permutation_.clear();
|
||||
for (uint32_t j = 0; j < individual_size; j++) {
|
||||
for (uint32_t i = 0; i < tasks.categories; i++) {
|
||||
tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
|
||||
}
|
||||
}
|
||||
return this->CategoryShuffle(tasks);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/mindrecord/include/shard_task.h"
|
||||
#include "minddata/mindrecord/include/shard_task_list.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||
|
||||
|
@ -25,55 +25,88 @@ using mindspore::MsLogLevel::DEBUG;
|
|||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardTask::ShardTask() : categories(1) {}
|
||||
ShardTaskList::ShardTaskList() : categories(1) {}
|
||||
|
||||
ShardTask::ShardTask(const ShardTask &other)
|
||||
: categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {}
|
||||
ShardTaskList::ShardTaskList(const ShardTaskList &other)
|
||||
: categories(other.categories),
|
||||
permutation_(other.permutation_),
|
||||
sample_ids_(other.sample_ids_),
|
||||
task_list_(other.task_list_) {}
|
||||
|
||||
ShardTask &ShardTask::operator=(const ShardTask &other) {
|
||||
ShardTask tmp(other);
|
||||
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
|
||||
ShardTaskList tmp(other);
|
||||
std::swap(categories, tmp.categories);
|
||||
permutation_.swap(tmp.permutation_);
|
||||
sample_ids_.swap(tmp.sample_ids_);
|
||||
task_list_.swap(tmp.task_list_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
void ShardTask::MakePerm() {
|
||||
permutation_ = std::vector<int>(task_list_.size());
|
||||
for (uint32_t i = 0; i < task_list_.size(); i++) {
|
||||
void ShardTaskList::InitSampleIds() {
|
||||
// no-op if there already exists sample ids. Do not clobber previous list
|
||||
if (sample_ids_.empty()) {
|
||||
sample_ids_ = std::vector<int>(task_list_.size());
|
||||
for (int i = 0; i < task_list_.size(); i++) sample_ids_[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
void ShardTaskList::MakePerm() {
|
||||
size_t perm_size = sample_ids_.size();
|
||||
permutation_ = std::vector<int>(perm_size);
|
||||
for (uint32_t i = 0; i < perm_size; i++) {
|
||||
permutation_[i] = static_cast<int>(i);
|
||||
}
|
||||
}
|
||||
|
||||
void ShardTask::PopBack() { task_list_.pop_back(); }
|
||||
// Swap the new_tasks with orig_tasks
|
||||
void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) {
|
||||
// When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a
|
||||
// new_tasks that does not have those fields will result in clobbering/losing the data after the swap.
|
||||
// The task_list_ should not be lost/clobbered.
|
||||
new_tasks.task_list_ = std::move(orig_tasks.task_list_);
|
||||
|
||||
uint32_t ShardTask::Size() const { return static_cast<uint32_t>(task_list_.size()); }
|
||||
// Now, it's safe to drive the swap.
|
||||
std::swap(orig_tasks, new_tasks);
|
||||
}
|
||||
|
||||
uint32_t ShardTask::SizeOfRows() const {
|
||||
void ShardTaskList::PopBack() { task_list_.pop_back(); }
|
||||
|
||||
uint32_t ShardTaskList::Size() const { return static_cast<uint32_t>(task_list_.size()); }
|
||||
|
||||
uint32_t ShardTaskList::SizeOfRows() const {
|
||||
if (task_list_.size() == 0) return static_cast<uint32_t>(0);
|
||||
|
||||
// 1 task is 1 page
|
||||
auto sum_num_rows = [](int x, std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> y) {
|
||||
return x + std::get<2>(y)[0];
|
||||
};
|
||||
auto sum_num_rows = [](int x, ShardTask y) { return x + std::get<2>(y)[0]; };
|
||||
uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
|
||||
return nRows;
|
||||
}
|
||||
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
|
||||
ShardTask &ShardTaskList::GetTaskByID(size_t id) {
|
||||
MS_ASSERT(id < task_list_.size());
|
||||
return task_list_[id];
|
||||
}
|
||||
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
|
||||
int ShardTaskList::GetTaskSampleByID(size_t id) {
|
||||
MS_ASSERT(id < sample_ids_.size());
|
||||
return sample_ids_[id];
|
||||
}
|
||||
|
||||
int ShardTaskList::GetRandomTaskID() {
|
||||
std::mt19937 gen = mindspore::dataset::GetRandomDevice();
|
||||
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
|
||||
return dis(gen);
|
||||
}
|
||||
|
||||
ShardTask &ShardTaskList::GetRandomTask() {
|
||||
std::mt19937 gen = mindspore::dataset::GetRandomDevice();
|
||||
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
|
||||
return task_list_[dis(gen)];
|
||||
}
|
||||
|
||||
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples) {
|
||||
ShardTask res;
|
||||
ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples) {
|
||||
ShardTaskList res;
|
||||
if (category_tasks.empty()) return res;
|
||||
auto total_categories = category_tasks.size();
|
||||
res.categories = static_cast<uint32_t>(total_categories);
|
||||
|
@ -107,6 +140,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
} // namespace mindrecord
|
Loading…
Reference in New Issue