MindRecord becomes a mappable dataset

updates

review touches

ci fixes

ci fix

ci fixin

ci fix

review updates

further cleanup

updates

updates

update

update comment
This commit is contained in:
Jamie Nisbet 2021-04-07 16:16:02 -04:00
parent dbb79b3e49
commit 515c936b85
24 changed files with 495 additions and 248 deletions

View File

@ -25,7 +25,7 @@
#include "minddata/dataset/include/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/log_adapter.h"
@ -67,9 +67,13 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
if (build_num_padded_ > 0) {
sample_json = ToJson(build_sample_);
}
new_mind_record_op = std::make_shared<MindRecordOp>(
build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, build_op_connector_queue_size_,
build_columns_to_load_, build_operators_, build_num_padded_, sample_json, build_sample_bytes_);
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
new_mind_record_op =
std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_,
build_op_connector_queue_size_, build_columns_to_load_, build_operators_,
build_num_padded_, sample_json, build_sample_bytes_, std::move(shard_reader));
RETURN_IF_NOT_OK(new_mind_record_op->Init());
*ptr = std::move(new_mind_record_op);
@ -110,8 +114,10 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) {
MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded,
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes)
: MappableLeafOp(num_mind_record_workers, op_connector_queue_size, std::make_shared<SequentialSamplerRT>(0, 0)),
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes,
std::unique_ptr<ShardReader> shard_reader)
: MappableLeafOp(num_mind_record_workers, op_connector_queue_size,
std::make_shared<MindRecordSamplerRT>(shard_reader.get())),
dataset_file_(dataset_file),
load_dataset_(load_dataset),
columns_to_load_(columns_to_load),
@ -120,7 +126,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
ended_worker_(0),
num_padded_(num_padded),
sample_json_(sample_json),
sample_bytes_(sample_bytes) {
sample_bytes_(sample_bytes),
shard_reader_(std::move(shard_reader)) {
io_block_queues_.Init(num_workers_, op_connector_queue_size);
epoch_sync_flag_ = true; // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all
// tasks are consumed by the worker threads would cause problem.
@ -128,7 +135,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
// Private helper method to encapsulate some common construction/reset tasks
Status MindRecordOp::Init() {
shard_reader_ = std::make_unique<ShardReader>();
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
num_padded_);
@ -363,9 +369,6 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
Status MindRecordOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
RETURN_IF_NOT_OK(MappableLeafOp::Reset()); // Call our super class reset first.
shard_reader_->ShuffleTask();
return Status::OK();
}

View File

@ -140,7 +140,8 @@ class MindRecordOp : public MappableLeafOp {
MindRecordOp(int32_t num_mind_record_workers, std::vector<std::string> dataset_file, bool load_dataset,
int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
const std::vector<std::shared_ptr<ShardOperator>> &operators, int64_t num_padded_,
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_);
const mindrecord::json &sample_json, const std::map<std::string, std::string> &sample_bytes_,
std::unique_ptr<ShardReader> shard_reader);
// Destructor
~MindRecordOp() override;

View File

@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
subset_random_sampler.cc
subset_sampler.cc
weighted_random_sampler.cc
mind_record_sampler.cc
)
if(ENABLE_PYTHON)

View File

@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "minddata/mindrecord/include/shard_reader.h"
namespace mindspore {
namespace dataset {
MindRecordSamplerRT::MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, int64_t samples_per_tensor)
: SamplerRT(0, samples_per_tensor), next_id_(0), shard_reader_(shard_reader) {}
Status MindRecordSamplerRT::GetNextSample(TensorRow *out) {
if (next_id_ > num_samples_) {
RETURN_STATUS_UNEXPECTED("MindRecordSampler Internal Error");
} else if (next_id_ == num_samples_) {
(*out) = TensorRow(TensorRow::kFlagEOE);
} else {
std::shared_ptr<Tensor> sampleIdsTensor;
int64_t last_id = std::min(samples_per_tensor_ + next_id_, num_samples_);
RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIdsTensor, last_id - next_id_));
auto id_ptr = sampleIdsTensor->begin<int64_t>();
for (int64_t i = 0; i < (last_id - next_id_); i++) {
*(id_ptr + static_cast<ptrdiff_t>(i)) = (*sample_ids_)[i];
}
next_id_ = last_id;
(*out) = {sampleIdsTensor};
}
return Status::OK();
}
Status MindRecordSamplerRT::InitSampler() {
sample_ids_ = shard_reader_->GetSampleIds();
if (!sample_ids_) {
// Note, sample_ids_.empty() is okay and will just give no sample ids.
RETURN_STATUS_UNEXPECTED("ShardReader did not provide a valid sample ids vector via MindRecordSamplerRT");
}
// Usually, the num samples is given from the user interface. In our case, that data is in mindrecord.
// Mindrecord already created the sample ids at this point, so the num samples is the size of the sampled id list.
num_samples_ = sample_ids_->size();
return Status::OK();
}
Status MindRecordSamplerRT::ResetSampler() {
// drive the shard reader reshuffle tasks to redo the sampling for another epoch
next_id_ = 0;
shard_reader_->ShuffleTask();
return Status::OK();
}
void MindRecordSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
out << "\nSampler: MindRecordSampler";
if (show_all) {
// Call the super class for displaying any common detailed info
SamplerRT::SamplerPrint(out, show_all);
// Then add our own info if any
}
}
Status MindRecordSamplerRT::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sampler_name"] = "MindRecordSampler";
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_
#include <limits>
#include <memory>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/mindrecord/include/shard_reader.h"
namespace mindspore {
namespace dataset {
class MindRecordSamplerRT : public SamplerRT {
public:
// Constructor
// @param shard_reader - shard_reader
// @param int64_t samples_per_tensor - Num of Sampler Ids to fetch via 1 GetNextSample call
MindRecordSamplerRT(mindrecord::ShardReader *shard_reader,
int64_t samples_per_tensor = std::numeric_limits<int64_t>::max());
// Destructor.
~MindRecordSamplerRT() = default;
// Op calls this to get next set of sampleIds
// @param out - Tensor of sample ids to be returned to caller
// @return Status The status code returned
Status GetNextSample(TensorRow *out) override;
// meant to be called by base class or python
Status InitSampler() override;
// for next epoch of sampleIds
// @return Status The status code returned
Status ResetSampler() override;
void SamplerPrint(std::ostream &out, bool show_all) const override;
/// \brief Get the arguments of node
/// \param[out] out_json JSON string of all attributes
/// \return Status of the function
Status to_json(nlohmann::json *out_json) override;
private:
mindrecord::ShardReader *shard_reader_; // back pointer to the shard reader
const std::vector<int> *sample_ids_; // read-only back pointer into mind record sampler ids
int64_t next_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_

View File

@ -23,6 +23,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
@ -155,17 +156,19 @@ Status MindDataNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
std::shared_ptr<MindRecordOp> mindrecord_op;
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
// else if pass a vector to MindData(), it will be treated as specified files to be read
if (search_for_pattern_) {
std::vector<std::string> dataset_file_vec_ = {dataset_file_};
mindrecord_op =
std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_,
columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_);
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_file_vec_, search_for_pattern_,
connector_que_size_, columns_list_, operators_, num_padded_,
padded_sample_, sample_bytes_, std::move(shard_reader));
} else {
mindrecord_op =
std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_, connector_que_size_,
columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_);
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, dataset_files_, search_for_pattern_,
connector_que_size_, columns_list_, operators_, num_padded_,
padded_sample_, sample_bytes_, std::move(shard_reader));
}
RETURN_IF_NOT_OK(mindrecord_op->Init());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato
bool GetReplacement() const { return replacement_; }
MSRStatus Execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTaskList &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,7 +39,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha
~ShardDistributedSample() override{};
MSRStatus PreExecute(ShardTask &tasks) override;
MSRStatus PreExecute(ShardTaskList &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
@ -47,7 +47,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha
bool shuffle_;
int no_of_padded_samples_;
bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch
ShardTask task_; // maintain the input tasks in first epoch
ShardTaskList task_; // maintain the input tasks in first epoch
};
} // namespace mindrecord
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,7 +18,7 @@
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
#include <memory>
#include "minddata/mindrecord/include/shard_task.h"
#include "minddata/mindrecord/include/shard_task_list.h"
namespace mindspore {
namespace mindrecord {
@ -26,7 +26,7 @@ class __attribute__((visibility("default"))) ShardOperator {
public:
virtual ~ShardOperator() = default;
MSRStatus operator()(ShardTask &tasks) {
MSRStatus operator()(ShardTaskList &tasks) {
if (SUCCESS != this->PreExecute(tasks)) {
return FAILED;
}
@ -47,11 +47,11 @@ class __attribute__((visibility("default"))) ShardOperator {
virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }
virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; }
virtual MSRStatus Execute(ShardTask &tasks) = 0;
virtual MSRStatus Execute(ShardTaskList &tasks) = 0;
virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; }
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor
~ShardPkSample() override{};
MSRStatus SufExecute(ShardTask &tasks) override;
MSRStatus SufExecute(ShardTaskList &tasks) override;
int64_t GetNumSamples() const { return num_samples_; }

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -210,6 +210,9 @@ class API_PUBLIC ShardReader {
/// \brief get the size of blob data
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
/// \brief get a read-only ptr to the sampled ids for this epoch
const std::vector<int> *GetSampleIds();
protected:
/// \brief sqlite call back function
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
@ -322,7 +325,7 @@ class API_PUBLIC ShardReader {
std::vector<std::string> selected_columns_; // columns which will be read
std::map<string, uint64_t> column_schema_id_; // column-schema map
std::vector<std::shared_ptr<ShardOperator>> operators_; // data operators, including shuffle, sample and category
ShardTask tasks_; // shard task
ShardTaskList tasks_; // shard task list
std::mutex shard_locker_; // locker of shard
// flags
@ -339,7 +342,7 @@ class API_PUBLIC ShardReader {
std::mutex mtx_delivery_; // locker for delivery
std::condition_variable cv_delivery_; // conditional variable for delivery
std::condition_variable cv_iterator_; // conditional variable for iterator
std::atomic<int> task_id_; // task ID which is working
std::atomic<int> sample_id_position_; // index into the sample ids vector for the current sample id
std::atomic<int> deliver_id_; // delivery ID which is picked up by iterator
// map of delivery
std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator
~ShardSample() override{};
MSRStatus Execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTaskList &tasks) override;
MSRStatus UpdateTasks(ShardTask &tasks, int taking);
MSRStatus UpdateTasks(ShardTaskList &tasks, int taking);
MSRStatus SufExecute(ShardTask &tasks) override;
MSRStatus SufExecute(ShardTaskList &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar
~ShardSequentialSample() override{};
MSRStatus Execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTaskList &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,11 +31,14 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator
~ShardShuffle() override{};
MSRStatus Execute(ShardTask &tasks) override;
MSRStatus Execute(ShardTaskList &tasks) override;
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
private:
// Private helper function
MSRStatus CategoryShuffle(ShardTaskList &tasks);
uint32_t shuffle_seed_;
int64_t no_of_samples_;
bool replacement_;

View File

@ -1,109 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
#include <algorithm>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/common/shard_utils.h"
namespace mindspore {
namespace mindrecord {
class __attribute__((visibility("default"))) ShardTask {
public:
ShardTask();
ShardTask(const ShardTask &task); // copy construction
ShardTask &operator=(const ShardTask &task); // assignment operator
~ShardTask() = default;
void MakePerm();
inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
const json &label);
inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
const std::vector<uint64_t> &offset, const json &label);
inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
inline void InsertTask(const uint32_t &i,
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
void PopBack();
uint32_t Size() const;
uint32_t SizeOfRows() const;
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id);
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples);
inline void ResizeTask(const uint32_t &size);
uint32_t categories;
// The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n]
std::vector<int> permutation_;
// The data struct is as below:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
};
inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
const json &label) {
MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
}
inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
const std::vector<uint64_t> &offset, const json &label) {
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
}
inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
<< ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
<< ", size of task_list_: " << task_list_.size() << ".";
task_list_.push_back(std::move(task));
}
inline void ShardTask::InsertTask(const uint32_t &i,
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
task_list_[i] = std::move(task);
}
inline void ShardTask::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
} // namespace mindrecord
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_

View File

@ -0,0 +1,132 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_
#include <algorithm>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "minddata/mindrecord/include/common/shard_utils.h"
namespace mindspore {
namespace mindrecord {
// The data struct is as below:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
class __attribute__((visibility("default"))) ShardTaskList {
public:
ShardTaskList();
ShardTaskList(const ShardTaskList &task); // copy construction
ShardTaskList &operator=(const ShardTaskList &task); // assignment operator
~ShardTaskList() = default;
void InitSampleIds();
static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks);
// Assigns the task based on task id
inline void AssignTask(ShardTaskList &sourceTasks, size_t id);
inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
const json &label);
inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
const std::vector<uint64_t> &offset, const json &label);
inline void InsertTask(ShardTask task);
inline void InsertTask(const uint32_t &i, ShardTask task);
void MakePerm();
inline void InsertSampleId(int id);
void PopBack();
uint32_t Size() const;
uint32_t SizeOfRows() const;
ShardTask &GetTaskByID(size_t id);
ShardTask &GetRandomTask();
int GetTaskSampleByID(size_t id);
int GetRandomTaskID();
static ShardTaskList Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples);
inline void ResizeTask(const uint32_t &size);
uint32_t categories;
std::vector<int> permutation_; // A list of ints used for shuffling sample ids
std::vector<int> sample_ids_; // The list of actual ids that were sampled
std::vector<ShardTask> task_list_; // The full list of tasks
};
inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) {
// Insert the sample id from the source into ourself by indexing at id position.
// Important: The task list itself does not change.
int sample_id = sourceTasks.GetTaskSampleByID(id);
MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id;
sample_ids_.push_back(sample_id);
}
inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id,
const std::vector<uint64_t> &offset, const json &label) {
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
}
inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
const std::vector<uint64_t> &offset, const json &label) {
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
}
inline void ShardTaskList::InsertTask(ShardTask task) {
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
<< ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
<< ", size of task_list_: " << task_list_.size() << ".";
task_list_.push_back(std::move(task));
}
inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); }
inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); }
} // namespace mindrecord
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,7 +46,7 @@ ShardReader::ShardReader()
num_padded_(0),
num_rows_(0),
total_blob_size_(0),
task_id_(0),
sample_id_position_(0),
deliver_id_(0),
lazy_load_(false),
shard_sample_count_() {}
@ -1088,9 +1088,8 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator
i++;
}
}
// Generate task list, a task will create a batch
std::vector<ShardTask> categoryTasks(categories.size());
// Generate a vector of task lists. Each catogory has a list of tasks.
std::vector<ShardTaskList> categoryTasks(categories.size());
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
int category_index = 0;
for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
@ -1122,7 +1121,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator
}
}
}
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
tasks_.InitSampleIds();
if (SUCCESS != (*category_op)(tasks_)) {
return FAILED;
}
@ -1246,6 +1247,10 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
}
}
MS_LOG(DEBUG) << "Created initial list of tasks. There are " << tasks_.Size() << " to start with before sampling.";
tasks_.InitSampleIds();
for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
const auto &op = operators[operator_no];
if (std::dynamic_pointer_cast<ShardCategory>(op)) continue;
@ -1256,7 +1261,9 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
if (tasks_.permutation_.empty()) tasks_.MakePerm();
num_rows_ = tasks_.Size();
MS_LOG(INFO) << "Total rows is " << num_rows_;
MS_LOG(INFO) << "Total rows is " << num_rows_
<< " and total amount sampled initially is: " << tasks_.sample_ids_.size();
return SUCCESS;
}
@ -1272,9 +1279,9 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
uint32_t blob_start = 0;
uint32_t blob_end = 0;
json var_fields;
// Pick up task from task list
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
ShardTask task;
task = tasks_.GetTaskByID(task_id);
// check task type
auto task_type = std::get<0>(task);
@ -1354,16 +1361,16 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
// Loop forever
for (;;) {
int task_id = 0;
int sample_id_pos = 0;
// Get next task ID
task_id = task_id_++;
sample_id_pos = sample_id_position_++;
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
if (sample_id_pos >= static_cast<int>(tasks_.sample_ids_.size())) {
return FAILED;
}
const auto &ret = ConsumerOneTask(task_id, consumer_id);
const auto &ret = ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id);
if (SUCCESS != ret.first) {
return FAILED;
}
@ -1372,11 +1379,13 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) {
// otherwise, set batch data in map
{
std::unique_lock<std::mutex> lck(mtx_delivery_);
cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; });
cv_delivery_.wait(lck,
[sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; });
if (interrupt_) {
return SUCCESS;
}
delivery_map_[task_id] = std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
delivery_map_[sample_id_pos] =
std::make_shared<std::vector<std::tuple<std::vector<uint8_t>, json>>>(std::move(batch));
}
cv_iterator_.notify_one();
}
@ -1386,7 +1395,7 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNext() {
if (interrupt_) {
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
}
if (deliver_id_ >= static_cast<int>(tasks_.Size())) {
if (deliver_id_ >= static_cast<int>(tasks_.sample_ids_.size())) {
return std::vector<std::tuple<std::vector<uint8_t>, json>>();
}
@ -1458,7 +1467,7 @@ std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> Sha
void ShardReader::Reset() {
{
std::lock_guard<std::mutex> lck(mtx_delivery_);
task_id_ = 0;
sample_id_position_ = 0;
deliver_id_ = 0;
}
cv_delivery_.notify_all();
@ -1486,5 +1495,10 @@ void ShardReader::ShuffleTask() {
if (tasks_.permutation_.empty()) tasks_.MakePerm();
}
const std::vector<int> *ShardReader::GetSampleIds() {
// return const reference to private sample id list.
return &(this->tasks_.sample_ids_);
}
} // namespace mindrecord
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
num_categories_(num_categories),
replacement_(replacement) {}
MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -55,7 +55,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
return 0;
}
MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
auto total_no = tasks.Size();
if (no_of_padded_samples_ > 0 && first_epoch_) {
if (total_no % denominator_ != 0) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,7 +37,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
}
MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) {
if (shuffle_ == true) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -80,21 +80,21 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return 0;
}
MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) {
MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
if (tasks.permutation_.empty()) {
ShardTask new_tasks;
int total_no = static_cast<int>(tasks.Size());
ShardTaskList new_tasks;
int total_no = static_cast<int>(tasks.sample_ids_.size());
if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
for (int i = 0; i < indices_.size(); ++i) {
int index = ((indices_[i] % total_no) + total_no) % total_no;
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
new_tasks.AssignTask(tasks, index); // different mod result between c and python
}
} else {
int count = 0;
if (nums_per_shard_.empty()) {
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
new_tasks.AssignTask(tasks, i % total_no); // rounding up. if overflow, go back to start
count++;
}
} else {
@ -102,33 +102,33 @@ MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) {
size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0;
for (; i < nums_per_shard_[partition_id_]; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
new_tasks.AssignTask(tasks, i % total_no);
count++;
}
}
}
std::swap(tasks, new_tasks);
ShardTaskList::TaskListSwap(tasks, new_tasks);
} else {
ShardTask new_tasks;
if (taking > static_cast<int>(tasks.permutation_.size())) {
ShardTaskList new_tasks;
if (taking > static_cast<int>(tasks.sample_ids_.size())) {
return FAILED;
}
int total_no = static_cast<int>(tasks.permutation_.size());
int count = 0;
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
count++;
}
std::swap(tasks, new_tasks);
ShardTaskList::TaskListSwap(tasks, new_tasks);
}
return SUCCESS;
}
MSRStatus ShardSample::Execute(ShardTask &tasks) {
MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
if (offset_ != -1) {
int64_t old_v = 0;
int num_rows_ = static_cast<int>(tasks.Size());
int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
for (int x = 0; x < denominator_; x++) {
int samples_per_buffer_ = (num_rows_ + offset_) / denominator_;
int remainder = (num_rows_ + offset_) % denominator_;
@ -140,8 +140,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
}
}
int no_of_categories = static_cast<int>(tasks.categories);
int total_no = static_cast<int>(tasks.Size()); // make sure task_size
int total_no = static_cast<int>(tasks.sample_ids_.size());
int taking = 0;
if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1
no_of_samples_ = std::min(no_of_samples_, total_no);
@ -167,7 +166,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
return UpdateTasks(tasks, taking);
}
MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) {
if (sampler_type_ == kSubsetRandomSampler) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -38,9 +38,9 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c
return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
}
MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
int64_t total_no = static_cast<int64_t>(tasks.Size());
MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) {
int64_t taking;
int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size());
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
taking = total_no;
} else if (per_ > kEpsilon && per_ <= 1.0f) {
@ -50,22 +50,22 @@ MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
}
if (tasks.permutation_.empty()) {
ShardTask new_tasks;
ShardTaskList new_tasks;
total_no = static_cast<int64_t>(tasks.Size());
for (size_t i = offset_; i < taking + offset_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
new_tasks.AssignTask(tasks, i % total_no);
}
std::swap(tasks, new_tasks);
ShardTaskList::TaskListSwap(tasks, new_tasks);
} else { // shuffled
ShardTask new_tasks;
ShardTaskList new_tasks;
if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
return FAILED;
}
total_no = static_cast<int64_t>(tasks.permutation_.size());
for (size_t i = offset_; i < taking + offset_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
}
std::swap(tasks, new_tasks);
ShardTaskList::TaskListSwap(tasks, new_tasks);
}
return SUCCESS;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,7 +42,31 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
}
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
uint32_t individual_size;
individual_size = tasks.sample_ids_.size() / tasks.categories;
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
for (uint32_t i = 0; i < tasks.categories; i++) {
for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
}
tasks.permutation_.clear(); // Jamie replace this we setting flag to false or something
for (uint32_t j = 0; j < individual_size; j++) {
for (uint32_t i = 0; i < tasks.categories; i++) {
tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
}
}
ShardTaskList new_tasks;
for (size_t i = 0; i < individual_size; ++i) {
new_tasks.AssignTask(tasks, tasks.permutation_[i]);
}
ShardTaskList::TaskListSwap(tasks, new_tasks);
return SUCCESS;
}
MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
if (reshuffle_each_epoch_) shuffle_seed_++;
if (tasks.categories < 1) {
return FAILED;
@ -52,43 +76,31 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
tasks.MakePerm();
}
if (replacement_ == true) {
ShardTask new_tasks;
if (no_of_samples_ == 0) {
no_of_samples_ = static_cast<int>(tasks.Size());
}
ShardTaskList new_tasks;
if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
new_tasks.task_list_.reserve(no_of_samples_);
for (uint32_t i = 0; i < no_of_samples_; ++i) {
new_tasks.InsertTask(tasks.GetRandomTask());
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
}
std::swap(tasks, new_tasks);
ShardTaskList::TaskListSwap(tasks, new_tasks);
} else {
std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
auto total_no = static_cast<int64_t>(tasks.Size());
if (no_of_samples_ > 0 && no_of_samples_ < total_no) {
ShardTask new_tasks;
for (size_t i = 0; i < no_of_samples_; ++i) {
new_tasks.InsertTask(tasks.GetTaskByID(i));
}
std::swap(tasks, new_tasks);
ShardTaskList new_tasks;
size_t samples_to_assign =
(no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size();
for (size_t i = 0; i < samples_to_assign; ++i) {
new_tasks.AssignTask(tasks, tasks.permutation_[i]);
}
ShardTaskList::TaskListSwap(tasks, new_tasks);
}
} else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
uint32_t individual_size = tasks.Size() / tasks.categories;
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
for (uint32_t i = 0; i < tasks.categories; i++) {
for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
}
tasks.permutation_.clear();
for (uint32_t j = 0; j < individual_size; j++) {
for (uint32_t i = 0; i < tasks.categories; i++) {
tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
}
}
return this->CategoryShuffle(tasks);
}
return SUCCESS;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,7 +15,7 @@
*/
#include "minddata/dataset/util/random.h"
#include "minddata/mindrecord/include/shard_task.h"
#include "minddata/mindrecord/include/shard_task_list.h"
#include "utils/ms_utils.h"
#include "minddata/mindrecord/include/common/shard_utils.h"
@ -25,55 +25,88 @@ using mindspore::MsLogLevel::DEBUG;
namespace mindspore {
namespace mindrecord {
ShardTask::ShardTask() : categories(1) {}
ShardTaskList::ShardTaskList() : categories(1) {}
ShardTask::ShardTask(const ShardTask &other)
: categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {}
ShardTaskList::ShardTaskList(const ShardTaskList &other)
: categories(other.categories),
permutation_(other.permutation_),
sample_ids_(other.sample_ids_),
task_list_(other.task_list_) {}
ShardTask &ShardTask::operator=(const ShardTask &other) {
ShardTask tmp(other);
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
ShardTaskList tmp(other);
std::swap(categories, tmp.categories);
permutation_.swap(tmp.permutation_);
sample_ids_.swap(tmp.sample_ids_);
task_list_.swap(tmp.task_list_);
return *this;
}
void ShardTask::MakePerm() {
permutation_ = std::vector<int>(task_list_.size());
for (uint32_t i = 0; i < task_list_.size(); i++) {
void ShardTaskList::InitSampleIds() {
// no-op if there already exists sample ids. Do not clobber previous list
if (sample_ids_.empty()) {
sample_ids_ = std::vector<int>(task_list_.size());
for (int i = 0; i < task_list_.size(); i++) sample_ids_[i] = i;
}
}
void ShardTaskList::MakePerm() {
size_t perm_size = sample_ids_.size();
permutation_ = std::vector<int>(perm_size);
for (uint32_t i = 0; i < perm_size; i++) {
permutation_[i] = static_cast<int>(i);
}
}
void ShardTask::PopBack() { task_list_.pop_back(); }
// Swap the new_tasks with orig_tasks
void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) {
// When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a
// new_tasks that does not have those fields will result in clobbering/losing the data after the swap.
// The task_list_ should not be lost/clobbered.
new_tasks.task_list_ = std::move(orig_tasks.task_list_);
uint32_t ShardTask::Size() const { return static_cast<uint32_t>(task_list_.size()); }
// Now, it's safe to drive the swap.
std::swap(orig_tasks, new_tasks);
}
uint32_t ShardTask::SizeOfRows() const {
void ShardTaskList::PopBack() { task_list_.pop_back(); }
uint32_t ShardTaskList::Size() const { return static_cast<uint32_t>(task_list_.size()); }
uint32_t ShardTaskList::SizeOfRows() const {
if (task_list_.size() == 0) return static_cast<uint32_t>(0);
// 1 task is 1 page
auto sum_num_rows = [](int x, std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> y) {
return x + std::get<2>(y)[0];
};
auto sum_num_rows = [](int x, ShardTask y) { return x + std::get<2>(y)[0]; };
uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
return nRows;
}
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
ShardTask &ShardTaskList::GetTaskByID(size_t id) {
MS_ASSERT(id < task_list_.size());
return task_list_[id];
}
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
int ShardTaskList::GetTaskSampleByID(size_t id) {
MS_ASSERT(id < sample_ids_.size());
return sample_ids_[id];
}
int ShardTaskList::GetRandomTaskID() {
std::mt19937 gen = mindspore::dataset::GetRandomDevice();
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
return dis(gen);
}
ShardTask &ShardTaskList::GetRandomTask() {
std::mt19937 gen = mindspore::dataset::GetRandomDevice();
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
return task_list_[dis(gen)];
}
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples) {
ShardTask res;
ShardTaskList res;
if (category_tasks.empty()) return res;
auto total_categories = category_tasks.size();
res.categories = static_cast<uint32_t>(total_categories);
@ -107,6 +140,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
}
}
}
return res;
}
} // namespace mindrecord