diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index ad85d1a01ab..1bbfbea8be9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -25,7 +25,7 @@ #include "minddata/dataset/include/constants.h" #include "minddata/dataset/core/global_context.h" -#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" #include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/util/log_adapter.h" @@ -67,9 +67,13 @@ Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { if (build_num_padded_ > 0) { sample_json = ToJson(build_sample_); } - new_mind_record_op = std::make_shared( - build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, build_op_connector_queue_size_, - build_columns_to_load_, build_operators_, build_num_padded_, sample_json, build_sample_bytes_); + + std::unique_ptr shard_reader = std::make_unique(); + + new_mind_record_op = + std::make_shared(build_num_mind_record_workers_, build_dataset_file_, build_load_dataset_, + build_op_connector_queue_size_, build_columns_to_load_, build_operators_, + build_num_padded_, sample_json, build_sample_bytes_, std::move(shard_reader)); RETURN_IF_NOT_OK(new_mind_record_op->Init()); *ptr = std::move(new_mind_record_op); @@ -110,8 +114,10 @@ mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, const std::vector> &operators, int64_t num_padded, - const mindrecord::json &sample_json, const std::map &sample_bytes) - : MappableLeafOp(num_mind_record_workers, op_connector_queue_size, std::make_shared(0, 0)), + const mindrecord::json &sample_json, const std::map &sample_bytes, + std::unique_ptr shard_reader) + : MappableLeafOp(num_mind_record_workers, op_connector_queue_size, + std::make_shared(shard_reader.get())), dataset_file_(dataset_file), load_dataset_(load_dataset), columns_to_load_(columns_to_load), @@ -120,7 +126,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector(); auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, num_padded_); @@ -363,9 +369,6 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vectorShuffleTask(); - return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h index 0de1130a690..2efc37e0fde 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -140,7 +140,8 @@ class MindRecordOp : public MappableLeafOp { MindRecordOp(int32_t num_mind_record_workers, std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, const std::vector> &operators, int64_t num_padded_, - const mindrecord::json &sample_json, const std::map &sample_bytes_); + const mindrecord::json &sample_json, const std::map &sample_bytes_, + std::unique_ptr shard_reader); // Destructor ~MindRecordOp() override; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 98c31aa4b76..1854d3da101 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -10,6 +10,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES subset_random_sampler.cc subset_sampler.cc weighted_random_sampler.cc + mind_record_sampler.cc ) if(ENABLE_PYTHON) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc new file mode 100644 index 00000000000..3e305bbfed0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" + +#include +#include +#include + +#include "minddata/mindrecord/include/shard_reader.h" + +namespace mindspore { +namespace dataset { +MindRecordSamplerRT::MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, int64_t samples_per_tensor) + : SamplerRT(0, samples_per_tensor), next_id_(0), shard_reader_(shard_reader) {} + +Status MindRecordSamplerRT::GetNextSample(TensorRow *out) { + if (next_id_ > num_samples_) { + RETURN_STATUS_UNEXPECTED("MindRecordSampler Internal Error"); + } else if (next_id_ == num_samples_) { + (*out) = TensorRow(TensorRow::kFlagEOE); + } else { + std::shared_ptr sampleIdsTensor; + int64_t last_id = std::min(samples_per_tensor_ + next_id_, num_samples_); + RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIdsTensor, last_id - next_id_)); + auto id_ptr = sampleIdsTensor->begin(); + for (int64_t i = 0; i < (last_id - next_id_); i++) { + *(id_ptr + static_cast(i)) = (*sample_ids_)[i]; + } + next_id_ = last_id; + + (*out) = {sampleIdsTensor}; + } + return Status::OK(); +} + +Status MindRecordSamplerRT::InitSampler() { + sample_ids_ = shard_reader_->GetSampleIds(); + + if (!sample_ids_) { + // Note, sample_ids_.empty() is okay and will just give no sample ids. + RETURN_STATUS_UNEXPECTED("ShardReader did not provide a valid sample ids vector via MindRecordSamplerRT"); + } + + // Usually, the num samples is given from the user interface. In our case, that data is in mindrecord. + // Mindrecord already created the sample ids at this point, so the num samples is the size of the sampled id list. + num_samples_ = sample_ids_->size(); + return Status::OK(); +} + +Status MindRecordSamplerRT::ResetSampler() { + // drive the shard reader reshuffle tasks to redo the sampling for another epoch + next_id_ = 0; + shard_reader_->ShuffleTask(); + return Status::OK(); +} + +void MindRecordSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { + out << "\nSampler: MindRecordSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + SamplerRT::SamplerPrint(out, show_all); + // Then add our own info if any + } +} + +Status MindRecordSamplerRT::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["sampler_name"] = "MindRecordSampler"; + *out_json = args; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h new file mode 100644 index 00000000000..4669429587f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/mindrecord/include/shard_reader.h" + +namespace mindspore { +namespace dataset { +class MindRecordSamplerRT : public SamplerRT { + public: + // Constructor + // @param shard_reader - shard_reader + // @param int64_t samples_per_tensor - Num of Sampler Ids to fetch via 1 GetNextSample call + MindRecordSamplerRT(mindrecord::ShardReader *shard_reader, + int64_t samples_per_tensor = std::numeric_limits::max()); + + // Destructor. + ~MindRecordSamplerRT() = default; + + // Op calls this to get next set of sampleIds + // @param out - Tensor of sample ids to be returned to caller + // @return Status The status code returned + Status GetNextSample(TensorRow *out) override; + + // meant to be called by base class or python + Status InitSampler() override; + + // for next epoch of sampleIds + // @return Status The status code returned + Status ResetSampler() override; + + void SamplerPrint(std::ostream &out, bool show_all) const override; + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + + private: + mindrecord::ShardReader *shard_reader_; // back pointer to the shard reader + const std::vector *sample_ids_; // read-only back pointer into mind record sampler ids + int64_t next_id_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_MINDRECORD_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 624d211aa37..5858859b6c2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -23,6 +23,7 @@ #include #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" #include "minddata/dataset/engine/opt/pass.h" #include "minddata/dataset/util/status.h" @@ -155,17 +156,19 @@ Status MindDataNode::Build(std::vector> *const node_o RETURN_IF_NOT_OK(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_)); std::shared_ptr mindrecord_op; + std::unique_ptr shard_reader = std::make_unique(); + // If pass a string to MindData(), it will be treated as a pattern to search for matched files, // else if pass a vector to MindData(), it will be treated as specified files to be read if (search_for_pattern_) { std::vector dataset_file_vec_ = {dataset_file_}; - mindrecord_op = - std::make_shared(num_workers_, dataset_file_vec_, search_for_pattern_, connector_que_size_, - columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_); + mindrecord_op = std::make_shared(num_workers_, dataset_file_vec_, search_for_pattern_, + connector_que_size_, columns_list_, operators_, num_padded_, + padded_sample_, sample_bytes_, std::move(shard_reader)); } else { - mindrecord_op = - std::make_shared(num_workers_, dataset_files_, search_for_pattern_, connector_que_size_, - columns_list_, operators_, num_padded_, padded_sample_, sample_bytes_); + mindrecord_op = std::make_shared(num_workers_, dataset_files_, search_for_pattern_, + connector_que_size_, columns_list_, operators_, num_padded_, + padded_sample_, sample_bytes_, std::move(shard_reader)); } RETURN_IF_NOT_OK(mindrecord_op->Init()); diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h index 156e0ae0812..beee2b928f5 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato bool GetReplacement() const { return replacement_; } - MSRStatus Execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h index 6c60faae1be..9790e000182 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,15 +39,15 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha ~ShardDistributedSample() override{}; - MSRStatus PreExecute(ShardTask &tasks) override; + MSRStatus PreExecute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; private: bool shuffle_; int no_of_padded_samples_; - bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch - ShardTask task_; // maintain the input tasks in first epoch + bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch + ShardTaskList task_; // maintain the input tasks in first epoch }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h index fbed2ec50fa..e7719cda9d3 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ #include -#include "minddata/mindrecord/include/shard_task.h" +#include "minddata/mindrecord/include/shard_task_list.h" namespace mindspore { namespace mindrecord { @@ -26,7 +26,7 @@ class __attribute__((visibility("default"))) ShardOperator { public: virtual ~ShardOperator() = default; - MSRStatus operator()(ShardTask &tasks) { + MSRStatus operator()(ShardTaskList &tasks) { if (SUCCESS != this->PreExecute(tasks)) { return FAILED; } @@ -47,11 +47,11 @@ class __attribute__((visibility("default"))) ShardOperator { virtual std::shared_ptr GetChildOp() { return child_op_; } - virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } + virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; } - virtual MSRStatus Execute(ShardTask &tasks) = 0; + virtual MSRStatus Execute(ShardTaskList &tasks) = 0; - virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } + virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; } virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h index 7fe81c03eb2..fecdb97905d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor ~ShardPkSample() override{}; - MSRStatus SufExecute(ShardTask &tasks) override; + MSRStatus SufExecute(ShardTaskList &tasks) override; int64_t GetNumSamples() const { return num_samples_; } diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index 769b8555c55..a6fca37506f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -210,6 +210,9 @@ class API_PUBLIC ShardReader { /// \brief get the size of blob data MSRStatus GetTotalBlobSize(int64_t *total_blob_size); + /// \brief get a read-only ptr to the sampled ids for this epoch + const std::vector *GetSampleIds(); + protected: /// \brief sqlite call back function static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); @@ -322,7 +325,7 @@ class API_PUBLIC ShardReader { std::vector selected_columns_; // columns which will be read std::map column_schema_id_; // column-schema map std::vector> operators_; // data operators, including shuffle, sample and category - ShardTask tasks_; // shard task + ShardTaskList tasks_; // shard task list std::mutex shard_locker_; // locker of shard // flags @@ -339,7 +342,7 @@ class API_PUBLIC ShardReader { std::mutex mtx_delivery_; // locker for delivery std::condition_variable cv_delivery_; // conditional variable for delivery std::condition_variable cv_iterator_; // conditional variable for iterator - std::atomic task_id_; // task ID which is working + std::atomic sample_id_position_; // index into the sample ids vector for the current sample id std::atomic deliver_id_; // delivery ID which is picked up by iterator // map of delivery std::unordered_map, json>>>> delivery_map_; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h index fb08765c182..6f469625df0 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator ~ShardSample() override{}; - MSRStatus Execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTaskList &tasks) override; - MSRStatus UpdateTasks(ShardTask &tasks, int taking); + MSRStatus UpdateTasks(ShardTaskList &tasks, int taking); - MSRStatus SufExecute(ShardTask &tasks) override; + MSRStatus SufExecute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h index dd4406e768c..6b17497d53b 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar ~ShardSequentialSample() override{}; - MSRStatus Execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h index be16eff94ed..b22b7d9f87f 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,11 +31,14 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator ~ShardShuffle() override{}; - MSRStatus Execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; private: + // Private helper function + MSRStatus CategoryShuffle(ShardTaskList &tasks); + uint32_t shuffle_seed_; int64_t no_of_samples_; bool replacement_; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h deleted file mode 100644 index 395eda3a3de..00000000000 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ -#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ - -#include -#include -#include -#include -#include -#include -#include "minddata/mindrecord/include/common/shard_utils.h" - -namespace mindspore { -namespace mindrecord { -class __attribute__((visibility("default"))) ShardTask { - public: - ShardTask(); - - ShardTask(const ShardTask &task); // copy construction - - ShardTask &operator=(const ShardTask &task); // assignment operator - - ~ShardTask() = default; - - void MakePerm(); - - inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label); - - inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, - const std::vector &offset, const json &label); - - inline void InsertTask(std::tuple, std::vector, json> task); - - inline void InsertTask(const uint32_t &i, - std::tuple, std::vector, json> task); - - void PopBack(); - - uint32_t Size() const; - - uint32_t SizeOfRows() const; - - std::tuple, std::vector, json> &GetTaskByID(size_t id); - - std::tuple, std::vector, json> &GetRandomTask(); - - static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, - int64_t num_samples); - - inline void ResizeTask(const uint32_t &size); - - uint32_t categories; - - // The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n] - std::vector permutation_; - - // The data struct is as below: - // 1. TaskType: kCommonTask / kPaddedTask - // 2. std::tuple : shard_id, group_id(fast load) / sample_id(lazy load) - // 3. std::vector, json>> : [blob_start, blob_end], scalar_variable_fields - std::vector, std::vector, json>> task_list_; -}; - -inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id - << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; - task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); -} - -inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, - const std::vector &offset, const json &label) { - task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label}; -} - -inline void ShardTask::InsertTask(std::tuple, std::vector, json> task) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) - << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() - << ", size of task_list_: " << task_list_.size() << "."; - - task_list_.push_back(std::move(task)); -} - -inline void ShardTask::InsertTask(const uint32_t &i, - std::tuple, std::vector, json> task) { - task_list_[i] = std::move(task); -} - -inline void ShardTask::ResizeTask(const uint32_t &size) { task_list_.resize(size); } -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h new file mode 100644 index 00000000000..243a1a34910 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h @@ -0,0 +1,132 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ +#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" + +namespace mindspore { +namespace mindrecord { + +// The data struct is as below: +// 1. TaskType: kCommonTask / kPaddedTask +// 2. std::tuple : shard_id, group_id(fast load) / sample_id(lazy load) +// 3. std::vector, json>> : [blob_start, blob_end], scalar_variable_fields +using ShardTask = std::tuple, std::vector, json>; + +class __attribute__((visibility("default"))) ShardTaskList { + public: + ShardTaskList(); + + ShardTaskList(const ShardTaskList &task); // copy construction + + ShardTaskList &operator=(const ShardTaskList &task); // assignment operator + + ~ShardTaskList() = default; + + void InitSampleIds(); + + static void TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks); + + // Assigns the task based on task id + inline void AssignTask(ShardTaskList &sourceTasks, size_t id); + + inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label); + + inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, + const std::vector &offset, const json &label); + + inline void InsertTask(ShardTask task); + + inline void InsertTask(const uint32_t &i, ShardTask task); + + void MakePerm(); + + inline void InsertSampleId(int id); + + void PopBack(); + + uint32_t Size() const; + + uint32_t SizeOfRows() const; + + ShardTask &GetTaskByID(size_t id); + + ShardTask &GetRandomTask(); + + int GetTaskSampleByID(size_t id); + + int GetRandomTaskID(); + + static ShardTaskList Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, + int64_t num_samples); + + inline void ResizeTask(const uint32_t &size); + + uint32_t categories; + + std::vector permutation_; // A list of ints used for shuffling sample ids + + std::vector sample_ids_; // The list of actual ids that were sampled + + std::vector task_list_; // The full list of tasks +}; + +inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, size_t id) { + // Insert the sample id from the source into ourself by indexing at id position. + // Important: The task list itself does not change. + int sample_id = sourceTasks.GetTaskSampleByID(id); + MS_LOG(DEBUG) << "Insert sample id (" << sample_id << ") into task list from source task position: " << id; + sample_ids_.push_back(sample_id); +} + +inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int group_id, + const std::vector &offset, const json &label) { + MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id + << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; + task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); +} + +inline void ShardTaskList::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id, + const std::vector &offset, const json &label) { + MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id + << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; + task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label}; +} + +inline void ShardTaskList::InsertTask(ShardTask task) { + MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task)) + << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() + << ", size of task_list_: " << task_list_.size() << "."; + + task_list_.push_back(std::move(task)); +} + +inline void ShardTaskList::InsertTask(const uint32_t &i, ShardTask task) { task_list_[i] = std::move(task); } + +inline void ShardTaskList::ResizeTask(const uint32_t &size) { task_list_.resize(size); } +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 8d771d219d8..f86da22ff58 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ ShardReader::ShardReader() num_padded_(0), num_rows_(0), total_blob_size_(0), - task_id_(0), + sample_id_position_(0), deliver_id_(0), lazy_load_(false), shard_sample_count_() {} @@ -1088,9 +1088,8 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr categoryTasks(categories.size()); + // Generate a vector of task lists. Each catogory has a list of tasks. + std::vector categoryTasks(categories.size()); for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { int category_index = 0; for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) { @@ -1122,7 +1121,9 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptrGetReplacement(), num_elements, num_samples); + tasks_ = ShardTaskList::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); + + tasks_.InitSampleIds(); if (SUCCESS != (*category_op)(tasks_)) { return FAILED; } @@ -1246,6 +1247,10 @@ MSRStatus ShardReader::CreateTasks(const std::vector(op)) continue; @@ -1256,7 +1261,9 @@ MSRStatus ShardReader::CreateTasks(const std::vector(task); @@ -1354,16 +1361,16 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { // Loop forever for (;;) { - int task_id = 0; + int sample_id_pos = 0; // Get next task ID - task_id = task_id_++; + sample_id_pos = sample_id_position_++; // All tasks are done - if (task_id >= static_cast(tasks_.Size())) { + if (sample_id_pos >= static_cast(tasks_.sample_ids_.size())) { return FAILED; } - const auto &ret = ConsumerOneTask(task_id, consumer_id); + const auto &ret = ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id); if (SUCCESS != ret.first) { return FAILED; } @@ -1372,11 +1379,13 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { // otherwise, set batch data in map { std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; }); + cv_delivery_.wait(lck, + [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; }); if (interrupt_) { return SUCCESS; } - delivery_map_[task_id] = std::make_shared, json>>>(std::move(batch)); + delivery_map_[sample_id_pos] = + std::make_shared, json>>>(std::move(batch)); } cv_iterator_.notify_one(); } @@ -1386,7 +1395,7 @@ std::vector, json>> ShardReader::GetNext() { if (interrupt_) { return std::vector, json>>(); } - if (deliver_id_ >= static_cast(tasks_.Size())) { + if (deliver_id_ >= static_cast(tasks_.sample_ids_.size())) { return std::vector, json>>(); } @@ -1458,7 +1467,7 @@ std::vector>, pybind11::object>> Sha void ShardReader::Reset() { { std::lock_guard lck(mtx_delivery_); - task_id_ = 0; + sample_id_position_ = 0; deliver_id_ = 0; } cv_delivery_.notify_all(); @@ -1486,5 +1495,10 @@ void ShardReader::ShuffleTask() { if (tasks_.permutation_.empty()) tasks_.MakePerm(); } +const std::vector *ShardReader::GetSampleIds() { + // return const reference to private sample id list. + return &(this->tasks_.sample_ids_); +} + } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc index eb1428a2ade..33c77e21dc9 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem num_categories_(num_categories), replacement_(replacement) {} -MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } +MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; } int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { if (dataset_size == 0) return dataset_size; diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc index 7ffea620986..4b4d57ee407 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_ return 0; } -MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) { +MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { auto total_no = tasks.Size(); if (no_of_padded_samples_ > 0 && first_epoch_) { if (total_no % denominator_ != 0) { diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc index ed4cf019dc7..aa01204d29a 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement } -MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { +MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) { if (shuffle_ == true) { if (SUCCESS != (*shuffle_op_)(tasks)) { return FAILED; diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc index 73bf6fae526..152cf67b7d1 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,21 +80,21 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } -MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) { +MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { if (tasks.permutation_.empty()) { - ShardTask new_tasks; - int total_no = static_cast(tasks.Size()); + ShardTaskList new_tasks; + int total_no = static_cast(tasks.sample_ids_.size()); if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { for (int i = 0; i < indices_.size(); ++i) { int index = ((indices_[i] % total_no) + total_no) % total_no; - new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python + new_tasks.AssignTask(tasks, index); // different mod result between c and python } } else { int count = 0; if (nums_per_shard_.empty()) { for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { if (no_of_samples_ != 0 && count == no_of_samples_) break; - new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start + new_tasks.AssignTask(tasks, i % total_no); // rounding up. if overflow, go back to start count++; } } else { @@ -102,33 +102,33 @@ MSRStatus ShardSample::UpdateTasks(ShardTask &tasks, int taking) { size_t i = partition_id_ - 1 >= 0 ? nums_per_shard_[partition_id_ - 1] : 0; for (; i < nums_per_shard_[partition_id_]; i++) { if (no_of_samples_ != 0 && count == no_of_samples_) break; - new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); + new_tasks.AssignTask(tasks, i % total_no); count++; } } } - std::swap(tasks, new_tasks); + ShardTaskList::TaskListSwap(tasks, new_tasks); } else { - ShardTask new_tasks; - if (taking > static_cast(tasks.permutation_.size())) { + ShardTaskList new_tasks; + if (taking > static_cast(tasks.sample_ids_.size())) { return FAILED; } int total_no = static_cast(tasks.permutation_.size()); int count = 0; for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { if (no_of_samples_ != 0 && count == no_of_samples_) break; - new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); + new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); count++; } - std::swap(tasks, new_tasks); + ShardTaskList::TaskListSwap(tasks, new_tasks); } return SUCCESS; } -MSRStatus ShardSample::Execute(ShardTask &tasks) { +MSRStatus ShardSample::Execute(ShardTaskList &tasks) { if (offset_ != -1) { int64_t old_v = 0; - int num_rows_ = static_cast(tasks.Size()); + int num_rows_ = static_cast(tasks.sample_ids_.size()); for (int x = 0; x < denominator_; x++) { int samples_per_buffer_ = (num_rows_ + offset_) / denominator_; int remainder = (num_rows_ + offset_) % denominator_; @@ -140,8 +140,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { } } int no_of_categories = static_cast(tasks.categories); - int total_no = static_cast(tasks.Size()); // make sure task_size - + int total_no = static_cast(tasks.sample_ids_.size()); int taking = 0; if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1 no_of_samples_ = std::min(no_of_samples_, total_no); @@ -167,7 +166,7 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { return UpdateTasks(tasks, taking); } -MSRStatus ShardSample::SufExecute(ShardTask &tasks) { +MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) { if (sampler_type_ == kSubsetRandomSampler) { if (SUCCESS != (*shuffle_op_)(tasks)) { return FAILED; diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc index ade1d496f5d..c7f57f43cd2 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,9 +38,9 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c return std::min(static_cast(no_of_samples_), dataset_size); } -MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { - int64_t total_no = static_cast(tasks.Size()); +MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { int64_t taking; + int64_t total_no = static_cast(tasks.sample_ids_.size()); if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { taking = total_no; } else if (per_ > kEpsilon && per_ <= 1.0f) { @@ -50,22 +50,22 @@ MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { } if (tasks.permutation_.empty()) { - ShardTask new_tasks; + ShardTaskList new_tasks; total_no = static_cast(tasks.Size()); for (size_t i = offset_; i < taking + offset_; ++i) { - new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); + new_tasks.AssignTask(tasks, i % total_no); } - std::swap(tasks, new_tasks); + ShardTaskList::TaskListSwap(tasks, new_tasks); } else { // shuffled - ShardTask new_tasks; + ShardTaskList new_tasks; if (taking > static_cast(tasks.permutation_.size())) { return FAILED; } total_no = static_cast(tasks.permutation_.size()); for (size_t i = offset_; i < taking + offset_; ++i) { - new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); + new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); } - std::swap(tasks, new_tasks); + ShardTaskList::TaskListSwap(tasks, new_tasks); } return SUCCESS; } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc index 3dcf9c95261..92d1addd00e 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,31 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); } -MSRStatus ShardShuffle::Execute(ShardTask &tasks) { +MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { + uint32_t individual_size; + individual_size = tasks.sample_ids_.size() / tasks.categories; + std::vector> new_permutations(tasks.categories, std::vector(individual_size)); + for (uint32_t i = 0; i < tasks.categories; i++) { + for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); + std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); + } + tasks.permutation_.clear(); // Jamie replace this we setting flag to false or something + for (uint32_t j = 0; j < individual_size; j++) { + for (uint32_t i = 0; i < tasks.categories; i++) { + tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + } + } + + ShardTaskList new_tasks; + for (size_t i = 0; i < individual_size; ++i) { + new_tasks.AssignTask(tasks, tasks.permutation_[i]); + } + ShardTaskList::TaskListSwap(tasks, new_tasks); + + return SUCCESS; +} + +MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { if (reshuffle_each_epoch_) shuffle_seed_++; if (tasks.categories < 1) { return FAILED; @@ -52,43 +76,31 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) { tasks.MakePerm(); } if (replacement_ == true) { - ShardTask new_tasks; - if (no_of_samples_ == 0) { - no_of_samples_ = static_cast(tasks.Size()); - } + ShardTaskList new_tasks; + if (no_of_samples_ == 0) no_of_samples_ = static_cast(tasks.sample_ids_.size()); if (no_of_samples_ <= 0) { MS_LOG(ERROR) << "no_of_samples need to be positive."; return FAILED; } new_tasks.task_list_.reserve(no_of_samples_); for (uint32_t i = 0; i < no_of_samples_; ++i) { - new_tasks.InsertTask(tasks.GetRandomTask()); + new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); } - std::swap(tasks, new_tasks); + + ShardTaskList::TaskListSwap(tasks, new_tasks); } else { std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); auto total_no = static_cast(tasks.Size()); - if (no_of_samples_ > 0 && no_of_samples_ < total_no) { - ShardTask new_tasks; - for (size_t i = 0; i < no_of_samples_; ++i) { - new_tasks.InsertTask(tasks.GetTaskByID(i)); - } - std::swap(tasks, new_tasks); + ShardTaskList new_tasks; + size_t samples_to_assign = + (no_of_samples_ > 0 && no_of_samples_ < total_no) ? no_of_samples_ : tasks.sample_ids_.size(); + for (size_t i = 0; i < samples_to_assign; ++i) { + new_tasks.AssignTask(tasks, tasks.permutation_[i]); } + ShardTaskList::TaskListSwap(tasks, new_tasks); } } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) - uint32_t individual_size = tasks.Size() / tasks.categories; - std::vector> new_permutations(tasks.categories, std::vector(individual_size)); - for (uint32_t i = 0; i < tasks.categories; i++) { - for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); - std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); - } - tasks.permutation_.clear(); - for (uint32_t j = 0; j < individual_size; j++) { - for (uint32_t i = 0; i < tasks.categories; i++) { - tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); - } - } + return this->CategoryShuffle(tasks); } return SUCCESS; } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc similarity index 54% rename from mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc rename to mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc index 07c34964a9e..0210d2057d0 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ #include "minddata/dataset/util/random.h" -#include "minddata/mindrecord/include/shard_task.h" +#include "minddata/mindrecord/include/shard_task_list.h" #include "utils/ms_utils.h" #include "minddata/mindrecord/include/common/shard_utils.h" @@ -25,55 +25,88 @@ using mindspore::MsLogLevel::DEBUG; namespace mindspore { namespace mindrecord { -ShardTask::ShardTask() : categories(1) {} +ShardTaskList::ShardTaskList() : categories(1) {} -ShardTask::ShardTask(const ShardTask &other) - : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {} +ShardTaskList::ShardTaskList(const ShardTaskList &other) + : categories(other.categories), + permutation_(other.permutation_), + sample_ids_(other.sample_ids_), + task_list_(other.task_list_) {} -ShardTask &ShardTask::operator=(const ShardTask &other) { - ShardTask tmp(other); +ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) { + ShardTaskList tmp(other); std::swap(categories, tmp.categories); permutation_.swap(tmp.permutation_); + sample_ids_.swap(tmp.sample_ids_); task_list_.swap(tmp.task_list_); return *this; } -void ShardTask::MakePerm() { - permutation_ = std::vector(task_list_.size()); - for (uint32_t i = 0; i < task_list_.size(); i++) { +void ShardTaskList::InitSampleIds() { + // no-op if there already exists sample ids. Do not clobber previous list + if (sample_ids_.empty()) { + sample_ids_ = std::vector(task_list_.size()); + for (int i = 0; i < task_list_.size(); i++) sample_ids_[i] = i; + } +} + +void ShardTaskList::MakePerm() { + size_t perm_size = sample_ids_.size(); + permutation_ = std::vector(perm_size); + for (uint32_t i = 0; i < perm_size; i++) { permutation_[i] = static_cast(i); } } -void ShardTask::PopBack() { task_list_.pop_back(); } +// Swap the new_tasks with orig_tasks +void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_tasks) { + // When swapping, if the orig_tasks contains fields that need to be preserved after the swap, then swapping with a + // new_tasks that does not have those fields will result in clobbering/losing the data after the swap. + // The task_list_ should not be lost/clobbered. + new_tasks.task_list_ = std::move(orig_tasks.task_list_); -uint32_t ShardTask::Size() const { return static_cast(task_list_.size()); } + // Now, it's safe to drive the swap. + std::swap(orig_tasks, new_tasks); +} -uint32_t ShardTask::SizeOfRows() const { +void ShardTaskList::PopBack() { task_list_.pop_back(); } + +uint32_t ShardTaskList::Size() const { return static_cast(task_list_.size()); } + +uint32_t ShardTaskList::SizeOfRows() const { if (task_list_.size() == 0) return static_cast(0); // 1 task is 1 page - auto sum_num_rows = [](int x, std::tuple, std::vector, json> y) { - return x + std::get<2>(y)[0]; - }; + auto sum_num_rows = [](int x, ShardTask y) { return x + std::get<2>(y)[0]; }; uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); return nRows; } -std::tuple, std::vector, json> &ShardTask::GetTaskByID(size_t id) { +ShardTask &ShardTaskList::GetTaskByID(size_t id) { MS_ASSERT(id < task_list_.size()); return task_list_[id]; } -std::tuple, std::vector, json> &ShardTask::GetRandomTask() { +int ShardTaskList::GetTaskSampleByID(size_t id) { + MS_ASSERT(id < sample_ids_.size()); + return sample_ids_[id]; +} + +int ShardTaskList::GetRandomTaskID() { + std::mt19937 gen = mindspore::dataset::GetRandomDevice(); + std::uniform_int_distribution<> dis(0, task_list_.size() - 1); + return dis(gen); +} + +ShardTask &ShardTaskList::GetRandomTask() { std::mt19937 gen = mindspore::dataset::GetRandomDevice(); std::uniform_int_distribution<> dis(0, task_list_.size() - 1); return task_list_[dis(gen)]; } -ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, - int64_t num_samples) { - ShardTask res; +ShardTaskList ShardTaskList::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements, + int64_t num_samples) { + ShardTaskList res; if (category_tasks.empty()) return res; auto total_categories = category_tasks.size(); res.categories = static_cast(total_categories); @@ -107,6 +140,7 @@ ShardTask ShardTask::Combine(std::vector &category_tasks, bool replac } } } + return res; } } // namespace mindrecord