!30553 Support dataset reset() to recover after failure

Merge pull request !30553 from h.farahat/reset
This commit is contained in:
i-robot 2022-03-04 02:19:45 +00:00 committed by Gitee
commit b90cf43562
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
26 changed files with 506 additions and 27 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -24,7 +24,8 @@
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer");
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer")
.def("Reset", [](TreeConsumer &self, int64_t step) { THROW_IF_ERROR(self.Reset(step)); });
}));
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,7 +42,11 @@ namespace dataset {
using ProfilingRegistrationState = ProfilingManager::ProfilingRegistrationState;
#endif
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
TreeConsumer::TreeConsumer() : TreeConsumer(1) {}
TreeConsumer::TreeConsumer(int32_t num_epochs) : num_epochs_(num_epochs) {
tree_adapter_ = std::make_unique<TreeAdapter>();
}
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) {
RETURN_IF_NOT_OK(tree_adapter_->Compile(std::move(d)));
@ -337,6 +341,39 @@ Status ToDevice::Terminate() {
return TreeConsumer::Terminate();
}
/// Reset the consumer to a given global step: tear down the running pipeline and
/// rebuild it with a skip injected (via AddSkipPass) so iteration resumes at `step`.
/// \param step Global step to resume from.
/// \return Status error code.
Status TreeConsumer::Reset(int64_t step) {
  MS_LOG(INFO) << "Resetting TreeConsumer";
  MS_LOG(INFO) << "Terminating pipeline with UUID:" << tree_adapter_->tree_->GetUniqueId();
  // Keep the original IR so the rebuilt pipeline is compiled from the same plan.
  std::shared_ptr<DatasetNode> old_root = tree_adapter_->input_ir_;
  RETURN_IF_NOT_OK(this->Stop());
  {
#ifdef ENABLE_PYTHON
    py::gil_scoped_release gil_release;  // release GIL to allow python threads to terminate.
#endif
    RETURN_IF_NOT_OK(this->Terminate());
  }
#ifdef ENABLE_GPUQUE
  // clear the device if GPU is used: drop rows already staged on the device queue.
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  if (op != nullptr) {
    MS_LOG(INFO) << "Clearing the GPU device";
    RETURN_IF_NOT_OK(op->ClearDevice());
  }
#endif
  // Rebuild with the reset usage flag so the pre-pass injects a skip up to `step`.
  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step));
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  MS_LOG(INFO) << "Launched a new pipeline after reset. UUID: " << tree_adapter_->tree_->GetUniqueId();
  std::shared_ptr<DatasetOp> root2 = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root2 != nullptr, "Root is a nullptr.");
  return Status::OK();
}
#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -40,6 +40,8 @@ class TreeConsumer {
/// Constructor that prepares an empty tree_adapter
TreeConsumer();
explicit TreeConsumer(int32_t num_epochs);
/// \brief Destructor
~TreeConsumer() = default;
/// Initializes the consumer, this involves constructing and preparing the tree.
@ -55,6 +57,16 @@ class TreeConsumer {
/// \return Offload JSON string.
std::string GetOffload();
/// Function to reset the current consumer to the provided step.
/// The consumer will terminate the pipeline and create a new one with skip injected.
/// \param step the step to reset the pipeline to.
/// \return Status error code
Status Reset(int64_t step);
/// Function to stop the consumer.
/// \return Status error code
virtual Status Stop() { return Status::OK(); }
#ifndef ENABLE_SECURITY
virtual Status RegisterProfilingManager();
@ -80,6 +92,8 @@ class TreeConsumer {
/// Method to return the name of the consumer
/// \return string
virtual std::string Name() = 0;
int32_t num_epochs_;
};
/// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map
@ -87,7 +101,7 @@ class IteratorConsumer : public TreeConsumer {
public:
/// Constructor which will call the base class default constructor.
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(num_epochs) {}
~IteratorConsumer() = default;
@ -116,7 +130,6 @@ class IteratorConsumer : public TreeConsumer {
std::string Name() override { return "IteratorConsumer"; }
private:
int32_t num_epochs_;
std::map<int32_t, std::string> column_order_; // key: column id, val: column name
};
@ -182,7 +195,7 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and send it to a device
class ToDevice : public TreeConsumer {
public:
explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(num_epochs) {}
~ToDevice() = default;
@ -198,7 +211,7 @@ class ToDevice : public TreeConsumer {
/// Stop to send data to device
/// \return Status error code
virtual Status Stop();
Status Stop() override;
/// Continue to send data to device
/// \return Status error code
@ -212,9 +225,6 @@ class ToDevice : public TreeConsumer {
/// Method to return the name of the consumer
/// \return string
std::string Name() override { return "ToDevice"; }
private:
int32_t num_epochs_;
};
/// Consumer that is used to get some pipeline information

View File

@ -669,6 +669,24 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
return Status::OK();
}
/// Clear any rows already staged in the GPU device queue for this op's channel.
/// Opens the channel, drains it, then closes the handle.
/// \return Status error code.
Status DeviceQueueOp::ClearDevice() {
  MS_LOG(INFO) << "Clearing the data in GPU device: " << device_id_ << " channel: " << channel_name_;
  auto release_function = std::bind(&DeviceQueueOp::ReleaseData, this, std::placeholders::_1, std::placeholders::_2);
  auto handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, {}, release_function);
  if (handle == INVALID_HANDLE) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                  "[Internal ERROR] Failed to open channel for clearing the device.");
  }
  BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Clear(handle);
  // Always close the handle, even when Clear fails, so the channel is not leaked.
  GpuBufferMgr::GetInstance().Close(handle);
  GpuBufferMgr::GetInstance().CloseConfirm();
  CHECK_FAIL_RETURN_UNEXPECTED(!ret, "Failed to clear the device.");
  return Status::OK();
}
#endif
Status DeviceQueueOp::SendDataToCPU() {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -87,6 +87,10 @@ class DeviceQueueOp : public PipelineOp {
void StopWaiting() { ascend_keep_waiting_ = false; }
#endif
#ifdef ENABLE_GPUQUE
Status ClearDevice();
#endif
Status GetDataInfo(DATA_INFO *data_info);
// Name: Print()

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -61,7 +61,9 @@ Status SkipOp::GetNextRow(TensorRow *row) {
}
if (row->eoe()) {
UpdateRepeatAndEpochCounter();
skip_count_ = 0;
if (!first_epoch_only_) {
skip_count_ = 0;
}
}
return Status::OK();
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -50,10 +50,14 @@ class SkipOp : public PipelineOp {
std::string Name() const override { return kSkipOp; }
Status GetNextRow(TensorRow *row) override;
void SetFirstEpochOnly(bool first_epoch_only) { first_epoch_only_ = first_epoch_only; }
private:
int32_t max_skips_; // The number of skips that the user requested
int32_t skip_count_; // A counter for the current number of executed skips
bool first_epoch_only_ = false;
std::unique_ptr<ChildIterator> child_iterator_; // An iterator for fetching.
};
} // namespace dataset

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -632,7 +632,7 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
*dataset_size = dataset_size_;
return Status::OK();
}
if (!IsSizeDefined()) {
if (!IsSizeDefined() && size_getter != nullptr) {
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
dataset_size_ = *dataset_size;
return Status::OK();

View File

@ -57,6 +57,12 @@ class RootNode : public DatasetNode {
/// \brief Getter of number of epochs
int32_t num_epochs() const { return num_epochs_; }
/// \brief Getter of number of epochs
int64_t step() const { return step_; }
/// \brief Setter of number of epochs
void SetStep(int64_t step) { step_ = step; }
/// \brief Setter of number of epochs
void SetNumEpochs(int32_t num_epochs) override { num_epochs_ = num_epochs; }
@ -78,6 +84,7 @@ class RootNode : public DatasetNode {
private:
int32_t num_epochs_;
int64_t step_; // to support reset
};
} // namespace dataset
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,6 +28,8 @@ namespace mindspore {
namespace dataset {
// Constructor for SkipNode
SkipNode::SkipNode(int32_t count) : skip_count_(count) {}
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { this->AddChild(child); }
std::shared_ptr<DatasetNode> SkipNode::Copy() {
@ -42,6 +44,9 @@ Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
auto op = std::make_shared<SkipOp>(skip_count_);
op->SetTotalRepeats(GetTotalRepeats());
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
if (first_epoch_only_) {
op->SetFirstEpochOnly(true);
}
node_ops->push_back(op);
return Status::OK();
}

View File

@ -27,6 +27,8 @@ namespace mindspore {
namespace dataset {
class SkipNode : public DatasetNode {
public:
explicit SkipNode(int32_t count);
/// \brief Constructor
explicit SkipNode(std::shared_ptr<DatasetNode> child, int32_t count);
@ -95,8 +97,11 @@ class SkipNode : public DatasetNode {
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
std::shared_ptr<DatasetNode> *result);
void SetFirstEpochOnly(bool flag) { first_epoch_only_ = flag; }
private:
int32_t skip_count_;
bool first_epoch_only_ = false;
};
} // namespace dataset
} // namespace mindspore

View File

@ -6,6 +6,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES
pass.cc
post/auto_worker_pass.cc
post/repeat_pass.cc
pre/add_skip_pass.cc
pre/cache_transform_pass.cc
pre/cache_validation_pass.cc
pre/deep_copy_pass.cc

View File

@ -0,0 +1,105 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/opt/pre/add_skip_pass.h"
#include <algorithm>
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
namespace mindspore {
namespace dataset {
// Constructor. The node argument is currently unused; the finder starts with no injection point.
AddSkipPass::InjectionFinder::InjectionFinder(const std::shared_ptr<DatasetNode> &node) : injection_point_(nullptr) {}
// Performs finder work for RootNode: records the injection point (the root's child) and the
// num_epochs/step values that RunOnTree later uses to compute the skip amount.
Status AddSkipPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
  RETURN_UNEXPECTED_IF_NULL(node);
  RETURN_UNEXPECTED_IF_NULL(modified);
  CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
                               "Invalid data, the number of children should be greater than zero.");
  // The injection is at the child of the root node
  injection_point_ = node->Children()[0];
  num_epochs_ = node->num_epochs();
  step_ = node->step();
  return Status::OK();
}
// Performs finder work for BuildVocabNode: vocab-building pipelines must not have a skip
// injected, so any previously recorded injection point is cleared.
Status AddSkipPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
  RETURN_UNEXPECTED_IF_NULL(node);
  RETURN_UNEXPECTED_IF_NULL(modified);
  injection_point_ = nullptr;
  return Status::OK();
}
#ifndef ENABLE_ANDROID
// Performs finder work for BuildSentenceVocabNode: like BuildVocabNode, these pipelines must
// not have a skip injected, so the injection point is cleared.
Status AddSkipPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
  RETURN_UNEXPECTED_IF_NULL(node);
  RETURN_UNEXPECTED_IF_NULL(modified);
  injection_point_ = nullptr;
  return Status::OK();
}
#endif
// When a TransferNode (device send) is present, move the injection point below it so the skip
// is injected on the host side of the pipeline rather than above the device transfer.
Status AddSkipPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
  RETURN_UNEXPECTED_IF_NULL(node);
  RETURN_UNEXPECTED_IF_NULL(modified);
  CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
                               "Invalid data, the number of children should be greater than zero.");
  // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here.
  // Move the injection point to the child of this node.
  injection_point_ = node->Children()[0];
  return Status::OK();
}
// Runs an injection pass to inject in operators needed at the pre pass stage.
// Translates the recovery step into (remaining epochs, rows to skip) and inserts a
// first-epoch-only SkipNode above the located injection point.
Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
  RETURN_UNEXPECTED_IF_NULL(root_ir);
  RETURN_UNEXPECTED_IF_NULL(modified);
  MS_LOG(INFO) << "Pre pass: Injection pass started.";

  // First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
  // The finder can make updates to the AddSkipPass object.
  AddSkipPass::InjectionFinder finder(root_ir);
  RETURN_IF_NOT_OK(finder.Run(root_ir, modified));

  // Check if we found a valid injection point (vocab pipelines clear it to forbid injection).
  std::shared_ptr<DatasetNode> node = finder.injection_point();
  CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Failed to inject SkipOp.");

  // Passing a null size_getter means only a previously cached dataset size can be used.
  int64_t dataset_size = -1;
  RETURN_IF_NOT_OK(root_ir->GetDatasetSize(nullptr, false, &dataset_size));
  CHECK_FAIL_RETURN_UNEXPECTED(dataset_size > 0, "Cannot reset the pipeline, dataset size is undefined");

  // Split the global step into whole epochs already completed and rows to skip in the first
  // epoch of the rebuilt pipeline.
  int32_t num_epochs = finder.GetNumEpochs();
  int64_t step = finder.GetStep();
  int32_t new_num_epochs = num_epochs - static_cast<int32_t>(step / dataset_size);
  int64_t skip_num = step % dataset_size;
  root_ir->SetNumEpochs(new_num_epochs);

  auto skip_node = std::make_shared<SkipNode>(skip_num);
  skip_node->SetFirstEpochOnly(true);
  RETURN_IF_NOT_OK(node->InsertAbove(skip_node));
  *modified = true;  // the tree was changed: a skip node was injected
  MS_LOG(INFO) << "Pre pass: Injection pass complete.";
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,99 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_ADD_SKIP_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_ADD_SKIP_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class AddSkipPass
/// \brief A pre pass that injects a SkipNode into the IR tree so that a pipeline rebuilt for
///     reset/recovery can fast-forward to a given global step.
class AddSkipPass : public IRTreePass {
  /// \class InjectionFinder
  /// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for
  ///     operators that need to be injected. It is run first by the main injection pass to find out what operators
  ///     it may need to inject.
  class InjectionFinder : public IRNodePass {
   public:
    /// \brief Constructor
    explicit InjectionFinder(const std::shared_ptr<DatasetNode> &node);

    /// \brief Destructor
    ~InjectionFinder() = default;

    /// \brief Performs finder work for RootNode that has special rules about skip injection.
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<RootNode> node, bool *const modified) override;

    /// \brief Performs finder work for BuildVocabNode that has special rules about skip injection.
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) override;

#ifndef ENABLE_ANDROID
    /// \brief Performs finder work for BuildSentenceVocabNode that has special rules about skip injection.
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) override;
#endif

    /// \brief Register the TransferNode for further action.
    /// \param[in] node The node being visited
    /// \param[in, out] modified Indicator if the node was changed at all
    /// \return Status The status code returned
    Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override;

    /// \brief Getter for the located injection point (nullptr when injection is not allowed).
    std::shared_ptr<DatasetNode> injection_point() { return injection_point_; }

    /// \brief Getter for the step recorded from the RootNode.
    int64_t GetStep() { return step_; }

    /// \brief Getter for the number of epochs recorded from the RootNode.
    int32_t GetNumEpochs() { return num_epochs_; }

   private:
    std::shared_ptr<DatasetNode> injection_point_;  // node the skip will be inserted above
    int64_t step_{};                                // global step to reset to
    int32_t num_epochs_{};                          // total epochs requested by the consumer
  };

 public:
  /// \brief Constructor
  AddSkipPass() {}

  /// \brief Destructor
  ~AddSkipPass() override = default;

  /// \brief Runs an injection pass to inject in operators needed at the pre pass stage
  /// \param[in, out] root_ir The tree to operate on.
  /// \param[in, out] modified Indicator of whether the tree was modified.
  /// \return Status The status code returned
  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_ADD_SKIP_PASS_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@
#ifdef ENABLE_PYTHON
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
#endif
#include "minddata/dataset/engine/opt/pre/add_skip_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
@ -57,6 +58,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
actions.emplace_back(std::make_unique<InputValidationPass>());
actions.emplace_back(std::make_unique<CacheValidationPass>());
actions.emplace_back(std::make_unique<NodeRemovalPass>());
if (usage_ == kDeReset) actions.emplace_back(std::make_unique<AddSkipPass>());
actions.emplace_back(std::make_unique<EpochCtrlPass>());
if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
#ifndef ENABLE_ANDROID
@ -176,9 +178,9 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
return Status::OK();
}
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs, int64_t step) {
RETURN_UNEXPECTED_IF_NULL(input_ir);
input_ir_ = input_ir;
tree_state_ = kCompileStateIRGraphBuilt;
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
@ -194,6 +196,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
RETURN_IF_NOT_OK(cloning_tree.Run(input_ir, &m));
std::shared_ptr<RootNode> root_ir = cloning_tree.Root();
root_ir->SetNumEpochs(num_epochs);
root_ir->SetStep(step);
tree_state_ = kCompileStateIRTreeCloned;
MS_LOG(INFO) << "Plan before optimization:" << '\n' << *root_ir << '\n';

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -49,7 +49,7 @@ class TreeAdapter {
// this flag is used to indicate the purpose of the creation of this tree adapter (type of the tree_consumer).
// Currently there are 3 types of consumer, Iterator, Getter and TDT/Vocab/Save ...
// To avoid premature optimization, the last type (TDT/Vocab/Save) is regarded as Iterator for now.
enum UsageFlag { kDeIterator = 0, kDeGetter = 1 };
enum UsageFlag { kDeIterator = 0, kDeGetter = 1, kDeReset = 2 };
explicit TreeAdapter(UsageFlag flag = kDeIterator);
@ -57,7 +57,7 @@ class TreeAdapter {
// This function performs syntax checking, semantics checking, optimizes, and then builds
// the Execution tree.
Status Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs = -1);
Status Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs = -1, int64_t step = 0);
// Return the root node of the IR after cloned from the parsed IR tree
std::shared_ptr<DatasetNode> RootIRNode() const { return root_ir_; }
@ -119,6 +119,7 @@ class TreeAdapter {
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
std::unordered_map<std::string, int32_t> column_name_map_;
std::shared_ptr<DatasetNode> input_ir_;
std::shared_ptr<DatasetNode> root_ir_;
std::unique_ptr<ExecutionTree> tree_;
bool optimize_; // Flag to enable optional optimization pass

View File

@ -142,6 +142,22 @@ BlockQueueStatus_T BlockingQueue::Pop() {
return SUCCESS;
}
// Drain every element currently in the queue, one at a time, under the queue mutex.
// Stops and propagates the first error reported by the underlying queue.
BlockQueueStatus_T BlockingQueue::Clear() {
  std::unique_lock<std::mutex> lock(mutex_);
  while (Size() > 0) {
    std::vector<DataItemGpu> discarded;
    BlockQueueStatus_T status = queue_->Front(&discarded);
    if (status) {
      return status;
    }
    status = queue_->Pop();
    if (status) {
      return status;
    }
  }
  return SUCCESS;
}
bool BlockingQueue::Destroy() {
if (queue_ != nullptr) {
return queue_->Destroy();

View File

@ -89,6 +89,7 @@ class BlockingQueue {
BlockQueueStatus_T Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec);
BlockQueueStatus_T Front(std::vector<DataItemGpu> *data);
BlockQueueStatus_T Pop();
BlockQueueStatus_T Clear();
bool Destroy();
size_t Size() { return queue_->Size(); }
size_t Capacity() { return queue_->Capacity(); }

View File

@ -130,6 +130,14 @@ BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) {
return iter->second->Pop();
}
// Clear the queue registered under `handle`; unknown handles report HANDLE_NOT_EXIST.
BlockQueueStatus_T GpuBufferMgr::Clear(unsigned int handle) {
  const auto it = handle_queue_map_.find(handle);
  if (it == handle_queue_map_.end()) {
    return HANDLE_NOT_EXIST;
  }
  return it->second->Clear();
}
void GpuBufferMgr::Close(unsigned int handle) noexcept {
if (!handle_queue_map_.count(handle)) {
return;

View File

@ -84,6 +84,7 @@ class GpuBufferMgr {
unsigned int timeout_in_sec);
EXPORT BlockQueueStatus_T Front(unsigned int handle, std::vector<DataItemGpu> *data);
EXPORT BlockQueueStatus_T Pop(unsigned int handle);
EXPORT BlockQueueStatus_T Clear(unsigned int handle);
EXPORT void set_device_id(int device_id);

View File

@ -144,6 +144,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/engine/ir/datasetops/repeat_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/project_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/shuffle_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/skip_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/album_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/mnist_node.cc
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/distributed_sampler_ir.cc
@ -161,6 +162,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/engine/datasetops/device_queue_op.cc
${MINDDATA_DIR}/engine/datasetops/project_op.cc
${MINDDATA_DIR}/engine/datasetops/shuffle_op.cc
${MINDDATA_DIR}/engine/datasetops/skip_op.cc
${MINDDATA_DIR}/engine/datasetops/pipeline_op.cc
${MINDDATA_DIR}/engine/datasetops/batch_op.cc
${MINDDATA_DIR}/engine/datasetops/map_op/map_op.cc
@ -170,6 +172,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
${MINDDATA_DIR}/engine/datasetops/source/mappable_leaf_op.cc
${MINDDATA_DIR}/engine/datasetops/source/io_block.cc
${MINDDATA_DIR}/engine/opt/pre/add_skip_pass.cc
${MINDDATA_DIR}/engine/opt/pre/getter_pass.cc
${MINDDATA_DIR}/engine/opt/pre/input_validation_pass.cc
${MINDDATA_DIR}/engine/opt/pre/cache_validation_pass.cc

View File

@ -86,6 +86,43 @@ OffloadToManualOffloadMode = {
True: cde.ManualOffloadMode.ENABLED
}
_train_dataset = None
def _set_training_dataset(dataset):
"""
Set the dataset to be used when training recovery has occurred.
Args:
dataset: the training dataset or iterator
"""
global _train_dataset
_train_dataset = dataset
def _get_training_dataset():
    """
    Return the dataset/iterator previously registered for training recovery.

    Returns:
        training dataset/iterator
    """
    return _train_dataset
def _reset_training_dataset(step):
    """
    Reset the registered training dataset to the given step number.

    Args:
        step (int): Global step number.

    Raises:
        RuntimeError: if no training dataset has been registered.
    """
    dataset = _get_training_dataset()
    if dataset is None:
        raise RuntimeError("Training dataset is not set.")
    dataset.reset(step)
class Shuffle(str, Enum):
"""Specify the shuffle mode.
@ -3352,6 +3389,9 @@ class _ToDevice:
def send(self):
self._to_device.Send()
def reset(self, step):
    # Forward the reset request to the native ToDevice consumer so the
    # device pipeline resumes from the given global step.
    self._to_device.Reset(step)
def stop_send(self):
"""
send stop send signal to pipeline, it is used when end of sequence is sent at the epoch end.
@ -3459,6 +3499,11 @@ class TransferDataset(Dataset):
if self._to_device is not None:
self._to_device.continue_send()
def reset(self, step):
    """Reset the transfer pipeline to `step`; no-op when no device sender exists yet."""
    if self._to_device is not None:
        logger.info("Reset the dataset pipeline to step " + str(step))
        self._to_device.reset(step)
def get_data_info(self):
"""
Get type and shape of current batch

View File

@ -178,6 +178,15 @@ class Iterator:
self._getters()
return self._col_names
def reset(self, step):
    """
    Reset the iterator to the given step number.

    Delegates to the C++ consumer's Reset, which rebuilds the pipeline with a
    skip injected so iteration resumes at `step`.

    Args:
        step (int): Global step number.
    """
    self._iterator.Reset(step)
class DictIterator(Iterator):
"""

View File

@ -330,6 +330,10 @@ class DatasetHelper:
"""Continue to send data to device at the beginning of epoch."""
self.iter.continue_send()
def reset(self, step):
    """Reset the dataset to the provided step."""
    # self.iter.reset is bound to the transfer dataset's reset in _DatasetIter.
    self.iter.reset(step)
def get_data_info(self):
"""
In sink mode, it returns the types and shapes of the current data.
@ -383,6 +387,7 @@ class _DatasetIter:
self.stop_send = dataset.__transfer_dataset__.stop_send
self.release = dataset.__transfer_dataset__.release
self.continue_send = dataset.__transfer_dataset__.continue_send
self.reset = dataset.__transfer_dataset__.reset
self.get_data_info = dataset.__transfer_dataset__.get_data_info
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

View File

@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -40,7 +40,7 @@ from ..parallel._cost_model_context import _set_multi_subgraphs
from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
from ..common.api import _pynative_executor, _cell_graph_executor
from ..dataset.engine.datasets import _set_training_dataset
def _transfer_tensor_to_tuple(inputs):
"""
@ -393,6 +393,10 @@ class Model:
if dataset_sink_mode:
network = connect_network_with_dataset(network, dataset_helper)
if is_train:
_set_training_dataset(dataset_helper) # pylint: disable=W0212
network.set_train(is_train)
network.phase = phase

View File

@ -0,0 +1,85 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing pipeline Reset
"""
import numpy as np
import pytest
import mindspore.dataset as ds
def create_np_dataset(size):
    """Build a non-shuffled NumpySlicesDataset holding the integers 1..size."""
    values = [i + 1 for i in range(size)]
    return ds.NumpySlicesDataset(values, shuffle=False)
def util(data, num_epochs, failure_point: int, reset_step):
    """Run `data` for `num_epochs`, simulate a failure at global step `failure_point`,
    reset the pipeline to `reset_step`, and verify rows before/after the reset against
    an uninterrupted reference run."""
    size = data.get_dataset_size()
    expected = []
    # Reference run: collect every row of a clean, uninterrupted num_epochs pass.
    expected_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
    for _ in range(num_epochs):
        for d in expected_itr:
            expected.append(d)
    del expected_itr

    actual_before_reset = []
    itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
    # Register the iterator so _reset_training_dataset can find it.
    ds.engine.datasets._set_training_dataset(itr)  # pylint: disable=W0212
    cur_step: int = 0
    failed = False
    # Consume rows until the simulated failure point, then trigger the reset and stop.
    for _ in range(num_epochs):
        for d in itr:
            actual_before_reset.append(d)
            if cur_step == failure_point:
                ds.engine.datasets._reset_training_dataset(reset_step)  # pylint: disable=W0212
                failed = True
                break
            cur_step += 1
        if failed:
            break

    actual_after_reset = []
    if failed:
        # Resume from the epoch containing reset_step; only the remaining epochs are left.
        for _ in range(reset_step // size, num_epochs):
            for d in itr:
                actual_after_reset.append(d)
        # Iterating past the requested number of epochs must raise.
        with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
            for _ in itr:
                pass

    # Rows seen before the failure must match the reference run up to failure_point.
    for x, y in zip(expected[:failure_point], actual_before_reset):
        np.testing.assert_array_equal(x, y)
    # Rows seen after the reset must match the reference run from reset_step onward.
    for x, y in zip(expected[reset_step:], actual_after_reset):
        np.testing.assert_array_equal(x, y)
np.testing.assert_array_equal(x, y)
def test_reset():
    """
    Feature: dataset recovery
    Description: Simple test of data pipeline reset feature
    Expectation: same datasets after reset
    """
    dataset_size = 5
    num_epochs = 3
    data = create_np_dataset(size=dataset_size)
    total_steps = dataset_size * num_epochs
    # Exercise every (failure point, reset target) combination across the whole run.
    for failure_point in range(total_steps):
        for reset_step in range(total_steps):
            util(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
# Allow running this test file directly, without pytest.
if __name__ == "__main__":
    test_reset()