forked from mindspore-Ecosystem/mindspore
!30553 Support dataset reset() to recover after failure
Merge pull request !30553 from h.farahat/reset
This commit is contained in:
commit
b90cf43562
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,7 +24,8 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) {
|
||||
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer");
|
||||
(void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer")
|
||||
.def("Reset", [](TreeConsumer &self, int64_t step) { THROW_IF_ERROR(self.Reset(step)); });
|
||||
}));
|
||||
PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) {
|
||||
(void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>(
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,7 +42,11 @@ namespace dataset {
|
|||
using ProfilingRegistrationState = ProfilingManager::ProfilingRegistrationState;
|
||||
#endif
|
||||
// TreeConsumer
|
||||
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
|
||||
TreeConsumer::TreeConsumer() : TreeConsumer(1) {}
|
||||
|
||||
// Constructor that records the requested number of epochs and prepares a default-constructed tree_adapter.
TreeConsumer::TreeConsumer(int32_t num_epochs) : num_epochs_(num_epochs) {
  auto fresh_adapter = std::make_unique<TreeAdapter>();
  tree_adapter_ = std::move(fresh_adapter);
}
|
||||
|
||||
Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) {
|
||||
RETURN_IF_NOT_OK(tree_adapter_->Compile(std::move(d)));
|
||||
|
@ -337,6 +341,39 @@ Status ToDevice::Terminate() {
|
|||
return TreeConsumer::Terminate();
|
||||
}
|
||||
|
||||
// Reset the consumer so that the pipeline resumes from the given step.
// The current pipeline is stopped and terminated, the GPU device queue is cleared when built with ENABLE_GPUQUE,
// and a new TreeAdapter (usage kDeReset) is compiled from the same input IR and launched.
// \param step the step to reset the pipeline to.
// \return Status error code; a failure in any stage is returned to the caller instead of being dropped.
Status TreeConsumer::Reset(int64_t step) {
  MS_LOG(INFO) << "Resetting TreeConsumer";
  // The UUID log below dereferences the execution tree, so make sure a pipeline exists first.
  CHECK_FAIL_RETURN_UNEXPECTED(tree_adapter_ != nullptr && tree_adapter_->tree_ != nullptr,
                               "Cannot reset the pipeline, it has not been built.");

  MS_LOG(INFO) << "Terminating pipeline with UUID:" << tree_adapter_->tree_->GetUniqueId();
  // Hold on to the input IR: the adapter that currently references it is replaced further down.
  std::shared_ptr<DatasetNode> old_root = tree_adapter_->input_ir_;
  RETURN_IF_NOT_OK(this->Stop());
  {
#ifdef ENABLE_PYTHON
    py::gil_scoped_release gil_release;  // release GIL to allow python threads to terminate.
#endif
    RETURN_IF_NOT_OK(this->Terminate());
  }

#ifdef ENABLE_GPUQUE
  // clear the device if GPU is used.
  std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get());
  if (op != nullptr) {
    // Only a pipeline whose root is a DeviceQueueOp has a device queue to clear.
    MS_LOG(INFO) << "Clearing the GPU device";
    RETURN_IF_NOT_OK(op->ClearDevice());
  }
#endif

  // Build and launch a replacement pipeline; kDeReset makes the pre pass inject the skip for `step`.
  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeReset);
  RETURN_IF_NOT_OK(tree_adapter_->Compile(old_root, num_epochs_, step));
  RETURN_IF_NOT_OK(tree_adapter_->Launch());
  MS_LOG(INFO) << "Launched a new pipeline after reset. UUID: " << tree_adapter_->tree_->GetUniqueId();
  std::shared_ptr<DatasetOp> root2 = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
  CHECK_FAIL_RETURN_UNEXPECTED(root2 != nullptr, "Root is a nullptr.");
  return Status::OK();
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// SaveToDisk
|
||||
Status SaveToDisk::ValidateParams() {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -40,6 +40,8 @@ class TreeConsumer {
|
|||
/// Constructor that prepares an empty tree_adapter
|
||||
TreeConsumer();
|
||||
|
||||
explicit TreeConsumer(int32_t num_epochs);
|
||||
|
||||
/// \brief Destructor
|
||||
~TreeConsumer() = default;
|
||||
/// Initializes the consumer, this involves constructing and preparing the tree.
|
||||
|
@ -55,6 +57,16 @@ class TreeConsumer {
|
|||
/// \return Offload JSON string.
|
||||
std::string GetOffload();
|
||||
|
||||
/// Function to reset the current consumer to the provided step.
|
||||
/// The consumer will terminate the pipeline and create a new one with skip injected.
|
||||
/// \param step the step to reset the pipeline to.
|
||||
/// \return Status error code
|
||||
Status Reset(int64_t step);
|
||||
|
||||
/// Function to stop the consumer.
|
||||
/// \return Status error code
|
||||
virtual Status Stop() { return Status::OK(); }
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
virtual Status RegisterProfilingManager();
|
||||
|
||||
|
@ -80,6 +92,8 @@ class TreeConsumer {
|
|||
/// Method to return the name of the consumer
|
||||
/// \return string
|
||||
virtual std::string Name() = 0;
|
||||
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
/// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map
|
||||
|
@ -87,7 +101,7 @@ class IteratorConsumer : public TreeConsumer {
|
|||
public:
|
||||
/// Constructor which will call the base class default constructor.
|
||||
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
|
||||
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
|
||||
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(num_epochs) {}
|
||||
|
||||
~IteratorConsumer() = default;
|
||||
|
||||
|
@ -116,7 +130,6 @@ class IteratorConsumer : public TreeConsumer {
|
|||
std::string Name() override { return "IteratorConsumer"; }
|
||||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
std::map<int32_t, std::string> column_order_; // key: column id, val: column name
|
||||
};
|
||||
|
||||
|
@ -182,7 +195,7 @@ class SaveToDisk : public TreeConsumer {
|
|||
/// Consumer that iterates over the dataset and send it to a device
|
||||
class ToDevice : public TreeConsumer {
|
||||
public:
|
||||
explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
|
||||
explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(num_epochs) {}
|
||||
|
||||
~ToDevice() = default;
|
||||
|
||||
|
@ -198,7 +211,7 @@ class ToDevice : public TreeConsumer {
|
|||
|
||||
/// Stop to send data to device
|
||||
/// \return Status error code
|
||||
virtual Status Stop();
|
||||
Status Stop() override;
|
||||
|
||||
/// Continue to send data to device
|
||||
/// \return Status error code
|
||||
|
@ -212,9 +225,6 @@ class ToDevice : public TreeConsumer {
|
|||
/// Method to return the name of the consumer
|
||||
/// \return string
|
||||
std::string Name() override { return "ToDevice"; }
|
||||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
/// Consumer that is used to get some pipeline information
|
||||
|
|
|
@ -669,6 +669,24 @@ Status DeviceQueueOp::MallocForGPUData(std::vector<device::DataItemGpu> *items,
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Clear the data queued on this op's GPU channel.
// Opens a handle on channel_name_, clears the queue behind it and closes the handle again.
// \return Status error code
Status DeviceQueueOp::ClearDevice() {
  MS_LOG(INFO) << "Clearing the data in GPU device: " << device_id_ << " channel: " << channel_name_;
  auto release_function = std::bind(&DeviceQueueOp::ReleaseData, this, std::placeholders::_1, std::placeholders::_2);
  // NOTE(review): the first argument is hard-coded to 0 while the log line above reports device_id_ - confirm.
  auto handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, {}, release_function);
  if (handle == INVALID_HANDLE) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                  "[Internal ERROR] Failed to open channel for clearing the device.");
  }

  BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Clear(handle);
  // Close the handle before checking the result so that a failed clear does not leave it open.
  GpuBufferMgr::GetInstance().Close(handle);
  CHECK_FAIL_RETURN_UNEXPECTED(!ret, "Failed to clear the device.");

  GpuBufferMgr::GetInstance().CloseConfirm();
  return Status::OK();
}
|
||||
|
||||
#endif
|
||||
|
||||
Status DeviceQueueOp::SendDataToCPU() {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -87,6 +87,10 @@ class DeviceQueueOp : public PipelineOp {
|
|||
void StopWaiting() { ascend_keep_waiting_ = false; }
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_GPUQUE
|
||||
Status ClearDevice();
|
||||
#endif
|
||||
|
||||
Status GetDataInfo(DATA_INFO *data_info);
|
||||
|
||||
// Name: Print()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -61,8 +61,10 @@ Status SkipOp::GetNextRow(TensorRow *row) {
|
|||
}
|
||||
if (row->eoe()) {
|
||||
UpdateRepeatAndEpochCounter();
|
||||
if (!first_epoch_only_) {
|
||||
skip_count_ = 0;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -50,10 +50,14 @@ class SkipOp : public PipelineOp {
|
|||
std::string Name() const override { return kSkipOp; }
|
||||
Status GetNextRow(TensorRow *row) override;
|
||||
|
||||
void SetFirstEpochOnly(bool first_epoch_only) { first_epoch_only_ = first_epoch_only; }
|
||||
|
||||
private:
|
||||
int32_t max_skips_; // The number of skips that the user requested
|
||||
int32_t skip_count_; // A counter for the current number of executed skips
|
||||
|
||||
bool first_epoch_only_ = false;
|
||||
|
||||
std::unique_ptr<ChildIterator> child_iterator_; // An iterator for fetching.
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -632,7 +632,7 @@ Status DatasetNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
|
|||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
if (!IsSizeDefined()) {
|
||||
if (!IsSizeDefined() && size_getter != nullptr) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size));
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
|
|
|
@ -57,6 +57,12 @@ class RootNode : public DatasetNode {
|
|||
/// \brief Getter of number of epochs
|
||||
int32_t num_epochs() const { return num_epochs_; }
|
||||
|
||||
/// \brief Getter of the step used to support reset
|
||||
int64_t step() const { return step_; }
|
||||
|
||||
/// \brief Setter of the step used to support reset
|
||||
void SetStep(int64_t step) { step_ = step; }
|
||||
|
||||
/// \brief Setter of number of epochs
|
||||
void SetNumEpochs(int32_t num_epochs) override { num_epochs_ = num_epochs; }
|
||||
|
||||
|
@ -78,6 +84,7 @@ class RootNode : public DatasetNode {
|
|||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
int64_t step_; // to support reset
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,6 +28,8 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
|
||||
// Constructor for SkipNode
|
||||
SkipNode::SkipNode(int32_t count) : skip_count_(count) {}
|
||||
|
||||
SkipNode::SkipNode(std::shared_ptr<DatasetNode> child, int32_t count) : skip_count_(count) { this->AddChild(child); }
|
||||
|
||||
std::shared_ptr<DatasetNode> SkipNode::Copy() {
|
||||
|
@ -42,6 +44,9 @@ Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
auto op = std::make_shared<SkipOp>(skip_count_);
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
if (first_epoch_only_) {
|
||||
op->SetFirstEpochOnly(true);
|
||||
}
|
||||
node_ops->push_back(op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -27,6 +27,8 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
class SkipNode : public DatasetNode {
|
||||
public:
|
||||
explicit SkipNode(int32_t count);
|
||||
|
||||
/// \brief Constructor
|
||||
explicit SkipNode(std::shared_ptr<DatasetNode> child, int32_t count);
|
||||
|
||||
|
@ -95,8 +97,11 @@ class SkipNode : public DatasetNode {
|
|||
static Status from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result);
|
||||
|
||||
void SetFirstEpochOnly(bool flag) { first_epoch_only_ = flag; }
|
||||
|
||||
private:
|
||||
int32_t skip_count_;
|
||||
bool first_epoch_only_ = false;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -6,6 +6,7 @@ set(DATASET_ENGINE_OPT_SRC_FILES
|
|||
pass.cc
|
||||
post/auto_worker_pass.cc
|
||||
post/repeat_pass.cc
|
||||
pre/add_skip_pass.cc
|
||||
pre/cache_transform_pass.cc
|
||||
pre/cache_validation_pass.cc
|
||||
pre/deep_copy_pass.cc
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/opt/pre/add_skip_pass.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/root_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// constructor
|
||||
AddSkipPass::InjectionFinder::InjectionFinder(const std::shared_ptr<DatasetNode> &node) : injection_point_(nullptr) {}
|
||||
|
||||
// Performs finder work for BuildVocabOp that has special rules about skip injection
|
||||
Status AddSkipPass::InjectionFinder::Visit(std::shared_ptr<RootNode> node, bool *const modified) {
|
||||
RETURN_UNEXPECTED_IF_NULL(node);
|
||||
RETURN_UNEXPECTED_IF_NULL(modified);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node->Children().size() > 0,
|
||||
"Invalid data, the number of children should be greater than zero.");
|
||||
// The injection is at the child of the root node
|
||||
injection_point_ = node->Children()[0];
|
||||
num_epochs_ = node->num_epochs();
|
||||
step_ = node->step();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Performs finder work for BuildVocabOp that has special rules about skip injection
|
||||
Status AddSkipPass::InjectionFinder::Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) {
|
||||
RETURN_UNEXPECTED_IF_NULL(node);
|
||||
RETURN_UNEXPECTED_IF_NULL(modified);
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// Performs finder work for BuildSentencePieceVocabNode that has special rules about skip injection
|
||||
Status AddSkipPass::InjectionFinder::Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) {
|
||||
RETURN_UNEXPECTED_IF_NULL(node);
|
||||
RETURN_UNEXPECTED_IF_NULL(modified);
|
||||
injection_point_ = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Performs finder work for TransferNode: the injection point is moved to the first child of the TransferNode.
Status AddSkipPass::InjectionFinder::VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) {
  RETURN_UNEXPECTED_IF_NULL(node);
  RETURN_UNEXPECTED_IF_NULL(modified);
  const auto child_count = node->Children().size();
  CHECK_FAIL_RETURN_UNEXPECTED(child_count > 0, "Invalid data, the number of children should be greater than zero.");

  // Assumption: There is only one TransferNode in a pipeline. This assumption is not validated here.
  // Move the injection point to the child of this node.
  injection_point_ = node->Children()[0];
  return Status::OK();
}
|
||||
|
||||
// Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
Status AddSkipPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) {
|
||||
RETURN_UNEXPECTED_IF_NULL(root_ir);
|
||||
RETURN_UNEXPECTED_IF_NULL(modified);
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass started.";
|
||||
|
||||
// First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
|
||||
// The finder can make updates to the AddSkipPass object.
|
||||
AddSkipPass::InjectionFinder finder(root_ir);
|
||||
RETURN_IF_NOT_OK(finder.Run(root_ir, modified));
|
||||
|
||||
// The first injection logic is to check if we should inject the skip op as the root node.
|
||||
std::shared_ptr<DatasetNode> node = finder.injection_point();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(node != nullptr, "Failed to inject SkipOp.");
|
||||
|
||||
int64_t dataset_size = -1;
|
||||
RETURN_IF_NOT_OK(root_ir->GetDatasetSize(nullptr, false, &dataset_size));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dataset_size > 0, "Cannot reset the pipeline, dataset size is undefined");
|
||||
int32_t num_epochs = finder.GetNumEpochs();
|
||||
int64_t step = finder.GetStep();
|
||||
int32_t new_num_epochs = num_epochs - static_cast<int32_t>(step / dataset_size);
|
||||
int64_t skip_num = step % dataset_size;
|
||||
|
||||
root_ir->SetNumEpochs(new_num_epochs);
|
||||
|
||||
auto skip_node = std::make_shared<SkipNode>(skip_num);
|
||||
skip_node->SetFirstEpochOnly(true);
|
||||
RETURN_IF_NOT_OK(node->InsertAbove(skip_node));
|
||||
|
||||
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_ADD_SKIP_PASS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_ADD_SKIP_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DatasetOp;
|
||||
|
||||
/// \class AddSkipPass
|
||||
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
|
||||
/// parsing.
|
||||
class AddSkipPass : public IRTreePass {
|
||||
/// \class InjectionFinder
|
||||
/// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for
|
||||
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
|
||||
/// it may need to inject.
|
||||
class InjectionFinder : public IRNodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit InjectionFinder(const std::shared_ptr<DatasetNode> &node);
|
||||
|
||||
/// \brief Destructor
|
||||
~InjectionFinder() = default;
|
||||
|
||||
/// \brief Performs finder work for RootNode that has special rules about skip injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<RootNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Performs finder work for BuildVocabNode that has special rules about skip injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<BuildVocabNode> node, bool *const modified) override;
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
/// \brief Performs finder work for BuildSentenceVocabNode that has special rules about skip injection.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<BuildSentenceVocabNode> node, bool *const modified) override;
|
||||
#endif
|
||||
|
||||
/// \brief Register the TransferNode for further action.
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[in, out] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<TransferNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Getter
|
||||
std::shared_ptr<DatasetNode> injection_point() { return injection_point_; }
|
||||
|
||||
int64_t GetStep() { return step_; }
|
||||
|
||||
int32_t GetNumEpochs() { return num_epochs_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<DatasetNode> injection_point_;
|
||||
int64_t step_{};
|
||||
int32_t num_epochs_{};
|
||||
};
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
AddSkipPass() {}
|
||||
|
||||
/// \brief Destructor
|
||||
~AddSkipPass() override = default;
|
||||
|
||||
/// \brief Runs an injection pass to inject in operators needed at the pre pass stage
|
||||
/// \param[in, out] tree The tree to operate on.
|
||||
/// \param[in, out] Indicate of the tree was modified.
|
||||
/// \return Status The status code returned
|
||||
Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *const modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PRE_ADD_SKIP_PASS_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@
|
|||
#ifdef ENABLE_PYTHON
|
||||
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/opt/pre/add_skip_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/cache_validation_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/deep_copy_pass.h"
|
||||
#include "minddata/dataset/engine/opt/pre/epoch_ctrl_pass.h"
|
||||
|
@ -57,6 +58,7 @@ Status TreeAdapter::PrePass(std::shared_ptr<DatasetNode> ir) {
|
|||
actions.emplace_back(std::make_unique<InputValidationPass>());
|
||||
actions.emplace_back(std::make_unique<CacheValidationPass>());
|
||||
actions.emplace_back(std::make_unique<NodeRemovalPass>());
|
||||
if (usage_ == kDeReset) actions.emplace_back(std::make_unique<AddSkipPass>());
|
||||
actions.emplace_back(std::make_unique<EpochCtrlPass>());
|
||||
if (usage_ == kDeGetter) actions.emplace_back(std::make_unique<GetterPass>());
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -176,9 +178,9 @@ Status TreeAdapter::Build(std::shared_ptr<DatasetNode> root_ir) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs) {
|
||||
Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_epochs, int64_t step) {
|
||||
RETURN_UNEXPECTED_IF_NULL(input_ir);
|
||||
|
||||
input_ir_ = input_ir;
|
||||
tree_state_ = kCompileStateIRGraphBuilt;
|
||||
MS_LOG(INFO) << "Input plan:" << '\n' << *input_ir << '\n';
|
||||
|
||||
|
@ -194,6 +196,7 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> input_ir, int32_t num_e
|
|||
RETURN_IF_NOT_OK(cloning_tree.Run(input_ir, &m));
|
||||
std::shared_ptr<RootNode> root_ir = cloning_tree.Root();
|
||||
root_ir->SetNumEpochs(num_epochs);
|
||||
root_ir->SetStep(step);
|
||||
|
||||
tree_state_ = kCompileStateIRTreeCloned;
|
||||
MS_LOG(INFO) << "Plan before optimization:" << '\n' << *root_ir << '\n';
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -49,7 +49,7 @@ class TreeAdapter {
|
|||
// this flag is used to indicate the purpose of the creation of this tree adapter (type of the tree_consumer).
|
||||
// Currently there are 3 types of consumer, Iterator, Getter and TDT/Vocab/Save ...
|
||||
// To avoid premature optimization, the last type (TDT/Vocab/Save) is regarded as Iterator for now.
|
||||
enum UsageFlag { kDeIterator = 0, kDeGetter = 1 };
|
||||
enum UsageFlag { kDeIterator = 0, kDeGetter = 1, kDeReset = 2 };
|
||||
|
||||
explicit TreeAdapter(UsageFlag flag = kDeIterator);
|
||||
|
||||
|
@ -57,7 +57,7 @@ class TreeAdapter {
|
|||
|
||||
// This function performs syntax checking, semantics checking, optimizes, and then builds
|
||||
// the Execution tree.
|
||||
Status Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs = -1);
|
||||
Status Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_epochs = -1, int64_t step = 0);
|
||||
|
||||
// Return the root node of the IR after cloned from the parsed IR tree
|
||||
std::shared_ptr<DatasetNode> RootIRNode() const { return root_ir_; }
|
||||
|
@ -119,6 +119,7 @@ class TreeAdapter {
|
|||
Status BuildExecutionTreeRecur(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op);
|
||||
|
||||
std::unordered_map<std::string, int32_t> column_name_map_;
|
||||
std::shared_ptr<DatasetNode> input_ir_;
|
||||
std::shared_ptr<DatasetNode> root_ir_;
|
||||
std::unique_ptr<ExecutionTree> tree_;
|
||||
bool optimize_; // Flag to enable optional optimization pass
|
||||
|
|
|
@ -142,6 +142,22 @@ BlockQueueStatus_T BlockingQueue::Pop() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
// Removes every item held by the underlying queue while holding mutex_.
// Each item is read with Front and then removed with Pop; the first non-success status is returned.
BlockQueueStatus_T BlockingQueue::Clear() {
  std::unique_lock<std::mutex> guard(mutex_);
  while (Size() != 0) {
    std::vector<DataItemGpu> discarded;
    auto status = queue_->Front(&discarded);
    if (!status) {
      status = queue_->Pop();
    }
    if (status) {
      return status;
    }
  }
  return SUCCESS;
}
|
||||
|
||||
bool BlockingQueue::Destroy() {
|
||||
if (queue_ != nullptr) {
|
||||
return queue_->Destroy();
|
||||
|
|
|
@ -89,6 +89,7 @@ class BlockingQueue {
|
|||
BlockQueueStatus_T Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec);
|
||||
BlockQueueStatus_T Front(std::vector<DataItemGpu> *data);
|
||||
BlockQueueStatus_T Pop();
|
||||
BlockQueueStatus_T Clear();
|
||||
bool Destroy();
|
||||
size_t Size() { return queue_->Size(); }
|
||||
size_t Capacity() { return queue_->Capacity(); }
|
||||
|
|
|
@ -130,6 +130,14 @@ BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) {
|
|||
return iter->second->Pop();
|
||||
}
|
||||
|
||||
// Clears the queue registered under the given handle.
// Returns HANDLE_NOT_EXIST when the handle is unknown, otherwise the status of the queue's Clear().
BlockQueueStatus_T GpuBufferMgr::Clear(unsigned int handle) {
  auto entry = handle_queue_map_.find(handle);
  const bool known_handle = entry != handle_queue_map_.end();
  if (known_handle) {
    return entry->second->Clear();
  }
  return HANDLE_NOT_EXIST;
}
|
||||
|
||||
void GpuBufferMgr::Close(unsigned int handle) noexcept {
|
||||
if (!handle_queue_map_.count(handle)) {
|
||||
return;
|
||||
|
|
|
@ -84,6 +84,7 @@ class GpuBufferMgr {
|
|||
unsigned int timeout_in_sec);
|
||||
EXPORT BlockQueueStatus_T Front(unsigned int handle, std::vector<DataItemGpu> *data);
|
||||
EXPORT BlockQueueStatus_T Pop(unsigned int handle);
|
||||
EXPORT BlockQueueStatus_T Clear(unsigned int handle);
|
||||
|
||||
EXPORT void set_device_id(int device_id);
|
||||
|
||||
|
|
|
@ -144,6 +144,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/ir/datasetops/repeat_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/project_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/shuffle_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/skip_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/album_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/mnist_node.cc
|
||||
${MINDDATA_DIR}/engine/ir/datasetops/source/samplers/distributed_sampler_ir.cc
|
||||
|
@ -161,6 +162,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/datasetops/device_queue_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/project_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/shuffle_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/skip_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/pipeline_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/batch_op.cc
|
||||
${MINDDATA_DIR}/engine/datasetops/map_op/map_op.cc
|
||||
|
@ -170,6 +172,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full")
|
|||
${MINDDATA_DIR}/engine/datasetops/source/mappable_leaf_op.cc
|
||||
|
||||
${MINDDATA_DIR}/engine/datasetops/source/io_block.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/add_skip_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/getter_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/input_validation_pass.cc
|
||||
${MINDDATA_DIR}/engine/opt/pre/cache_validation_pass.cc
|
||||
|
|
|
@ -86,6 +86,43 @@ OffloadToManualOffloadMode = {
|
|||
True: cde.ManualOffloadMode.ENABLED
|
||||
}
|
||||
|
||||
_train_dataset = None
|
||||
|
||||
|
||||
def _set_training_dataset(dataset):
    """
    Remember the dataset that is to be reset when training recovery occurs.

    Args:
        dataset: the training dataset or iterator
    """
    global _train_dataset
    _train_dataset = dataset
|
||||
|
||||
|
||||
def _get_training_dataset():
    """
    Return the dataset registered for training recovery.

    Returns:
        training dataset/iterator
    """
    registered = _train_dataset
    return registered
|
||||
|
||||
|
||||
def _reset_training_dataset(step):
    """
    Reset the registered training dataset to the given step number.

    Args:
        step (int): Global step number.

    Raises:
        RuntimeError: If no training dataset has been registered.
    """
    training_dataset = _get_training_dataset()
    if training_dataset is None:
        raise RuntimeError("Training dataset is not set.")
    training_dataset.reset(step)
|
||||
|
||||
|
||||
class Shuffle(str, Enum):
|
||||
"""Specify the shuffle mode.
|
||||
|
@ -3352,6 +3389,9 @@ class _ToDevice:
|
|||
def send(self):
    """Forward the send request to the wrapped `_to_device` object."""
    self._to_device.Send()
|
||||
|
||||
def reset(self, step):
    """
    Forward a reset request to the wrapped `_to_device` object.

    Args:
        step (int): Global step number.
    """
    self._to_device.Reset(step)
|
||||
|
||||
def stop_send(self):
|
||||
"""
|
||||
send stop send signal to pipeline, it is used when end of sequence is sent at the epoch end.
|
||||
|
@ -3459,6 +3499,11 @@ class TransferDataset(Dataset):
|
|||
if self._to_device is not None:
|
||||
self._to_device.continue_send()
|
||||
|
||||
def reset(self, step):
    """
    Reset the underlying device pipeline to the given step, if one exists.

    Does nothing when `self._to_device` is None.

    Args:
        step (int): Global step number.
    """
    if self._to_device is None:
        return
    logger.info("Reset the dataset pipeline to step " + str(step))
    self._to_device.reset(step)
|
||||
|
||||
def get_data_info(self):
|
||||
"""
|
||||
Get type and shape of current batch
|
||||
|
|
|
@ -178,6 +178,15 @@ class Iterator:
|
|||
self._getters()
|
||||
return self._col_names
|
||||
|
||||
def reset(self, step):
    """
    Move this iterator to the given step number.

    The request is delegated to the wrapped `_iterator` object.

    Args:
        step (int): Global step number.
    """
    self._iterator.Reset(step)
|
||||
|
||||
|
||||
class DictIterator(Iterator):
|
||||
"""
|
||||
|
|
|
@ -330,6 +330,10 @@ class DatasetHelper:
|
|||
"""Continue to send data to device at the beginning of epoch."""
|
||||
self.iter.continue_send()
|
||||
|
||||
def reset(self, step):
    """
    Reset the dataset to the provided step by delegating to `self.iter`.

    Args:
        step (int): Step number handed on to `self.iter.reset`.
    """
    self.iter.reset(step)
|
||||
|
||||
def get_data_info(self):
|
||||
"""
|
||||
In sink mode, it returns the types and shapes of the current data.
|
||||
|
@ -383,6 +387,7 @@ class _DatasetIter:
|
|||
self.stop_send = dataset.__transfer_dataset__.stop_send
|
||||
self.release = dataset.__transfer_dataset__.release
|
||||
self.continue_send = dataset.__transfer_dataset__.continue_send
|
||||
self.reset = dataset.__transfer_dataset__.reset
|
||||
self.get_data_info = dataset.__transfer_dataset__.get_data_info
|
||||
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -40,7 +40,7 @@ from ..parallel._cost_model_context import _set_multi_subgraphs
|
|||
from .dataset_helper import DatasetHelper, connect_network_with_dataset
|
||||
from . import amp
|
||||
from ..common.api import _pynative_executor, _cell_graph_executor
|
||||
|
||||
from ..dataset.engine.datasets import _set_training_dataset
|
||||
|
||||
def _transfer_tensor_to_tuple(inputs):
|
||||
"""
|
||||
|
@ -393,6 +393,10 @@ class Model:
|
|||
if dataset_sink_mode:
|
||||
network = connect_network_with_dataset(network, dataset_helper)
|
||||
|
||||
if is_train:
|
||||
_set_training_dataset(dataset_helper) # pylint: disable=W0212
|
||||
|
||||
|
||||
network.set_train(is_train)
|
||||
network.phase = phase
|
||||
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing pipeline Reset
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
||||
def create_np_dataset(size):
    """Build an unshuffled NumpySlicesDataset over the integers 1..size (inclusive)."""
    return ds.NumpySlicesDataset(list(range(1, size + 1)), shuffle=False)
|
||||
|
||||
|
||||
def util(data, num_epochs, failure_point: int, reset_step):
    """
    Consume `data`, trigger a pipeline reset part-way through, and compare against a reference run.

    A first iterator collects the reference rows of all epochs. A second iterator is
    registered as the training dataset, consumed until the step counter equals
    `failure_point`, reset to `reset_step`, and then consumed for the remaining epochs.
    Rows read before and after the reset are compared pairwise with the reference rows.

    Args:
        data: dataset under test; must offer get_dataset_size() and create_tuple_iterator().
        num_epochs (int): number of epochs both iterators are created with.
        failure_point (int): zero-based step counter value at which the reset is issued.
        reset_step (int): global step number passed to the reset call.
    """
    size = data.get_dataset_size()

    # Reference run: every row of every epoch, in order.
    expected = []
    reference_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
    for _ in range(num_epochs):
        expected.extend(reference_itr)
    del reference_itr

    # Run under test: read rows until the failure point, then issue the reset.
    rows_before_reset = []
    itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
    ds.engine.datasets._set_training_dataset(itr)  # pylint: disable=W0212
    step_count: int = 0
    failed = False
    for _ in range(num_epochs):
        for row in itr:
            rows_before_reset.append(row)
            if step_count == failure_point:
                ds.engine.datasets._reset_training_dataset(reset_step)  # pylint: disable=W0212
                failed = True
                break
            step_count += 1
        else:
            # Epoch ended without hitting the failure point; continue with the next epoch.
            continue
        # The inner loop was left through `break`, so the reset has been issued.
        break

    # After the reset, drain the epochs from the one containing `reset_step` onwards.
    rows_after_reset = []
    if failed:
        first_epoch = reset_step // size
        for _ in range(first_epoch, num_epochs):
            rows_after_reset.extend(itr)

    # All epochs are used up now, so one more pass over the iterator has to fail.
    with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
        for _ in itr:
            pass

    for want, got in zip(expected[:failure_point], rows_before_reset):
        np.testing.assert_array_equal(want, got)

    for want, got in zip(expected[reset_step:], rows_after_reset):
        np.testing.assert_array_equal(want, got)
|
||||
|
||||
|
||||
def test_reset():
    """
    Feature: dataset recovery
    Description: Exercise pipeline reset for every failure point / reset step pair of a small dataset
    Expectation: rows read before and after the reset match a run without reset
    """
    dataset_size = 5
    num_epochs = 3
    data = create_np_dataset(size=dataset_size)
    # Total number of steps across all epochs; both loop variables sweep this range.
    total_steps = dataset_size * num_epochs
    for failure_point in range(total_steps):
        for reset_step in range(total_steps):
            util(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
|
||||
|
||||
|
||||
# Run the reset test directly when this module is executed as a script.
if __name__ == "__main__":
|
||||
test_reset()
|
Loading…
Reference in New Issue