Implemented Callback for Dataset
Squashed commit history:
- implement pause in MapOp; add more to callback
- add ds_callback: initial drop of Python DSCallback
- pybind DSCallback
- add callback to MapOp
- de_pipeline DSCallback
- add test case (segfaulted at first); fix seg fault
- update callback test case, now works
- use builder class for MapOp callback
- better test case; minor fixes, comments, and clean-ups
- get rid of nullptr in MapOp, use another flag instead
- fix a bug: ParseMapOp only takes 1 callback
- add WaitedDSCallback
- refactor callback param
- fix test case incorrect number
- add testing; fix cpp test case
- revert lenet changes
- clean up test_callbacks.py
- fix CI stages I and II; fix CI and update epoch counter
- add validation and more testing
- test_callbacks.py: use random data op to do tests
- adjust when to call EpochBegin/End
- add repeat with callback
- address reviewers' comments; docstring and CI fixes
- rebase with upstream/master
- fix review comments; add test case
parent 89cd465268 · commit 78c1aa1d96
@@ -58,6 +58,7 @@ add_subdirectory(kernels)
 add_subdirectory(engine)
 add_subdirectory(api)
 add_subdirectory(text)
+add_subdirectory(callback)
 ######################################################################
 add_dependencies(utils core)
 add_dependencies(kernels-image core)
@@ -74,6 +75,7 @@ add_dependencies(engine-cache-server core)
 add_dependencies(engine-perf core)
 add_dependencies(engine-gnn core)
 add_dependencies(engine core)
+add_dependencies(callback core)
 add_dependencies(text core)
 add_dependencies(text-kernels core)
 add_dependencies(cpp-API core)
@@ -87,6 +89,7 @@ endif ()
 ################### Create _c_dataengine Library ######################
 set(submodules
     $<TARGET_OBJECTS:core>
+    $<TARGET_OBJECTS:callback>
     $<TARGET_OBJECTS:utils>
     $<TARGET_OBJECTS:kernels>
     $<TARGET_OBJECTS:kernels-image>
@@ -7,6 +7,7 @@ if (ENABLE_PYTHON)
         python/bindings.cc
         python/bindings/dataset/engine/cache/bindings.cc
         python/bindings/dataset/core/bindings.cc
+        python/bindings/dataset/callback/bindings.cc
         python/bindings/dataset/kernels/data/bindings.cc
         python/bindings/dataset/kernels/bindings.cc
         python/bindings/dataset/engine/datasetops/bindings.cc
@@ -0,0 +1,45 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h"

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/callback/ds_callback.h"

namespace mindspore {
namespace dataset {

PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
                  (void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback")
                    .def(py::init<int32_t>())
                    .def("set_begin", &PyDSCallback::setBegin)
                    .def("set_end", &PyDSCallback::setEnd)
                    .def("set_epoch_begin", &PyDSCallback::setEpochBegin)
                    .def("set_epoch_end", &PyDSCallback::setEpochEnd)
                    .def("set_step_begin", &PyDSCallback::setStepBegin)
                    .def("set_step_end", &PyDSCallback::setStepEnd);
                }));

PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
                  (void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam")
                    .def(py::init<int64_t, int64_t, int64_t>())
                    .def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
                    .def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
                    .def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
                }));
}  // namespace dataset
}  // namespace mindspore
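These bindings are what the Python layer drives. A minimal, illustrative sketch of using them directly (the module name mindspore._c_dataengine comes from the Python files later in this diff; user code normally goes through the DSCallback wrapper instead):

# Illustrative only; assumes the bindings above are compiled into mindspore._c_dataengine.
from mindspore._c_dataengine import PyDSCallback, CallbackParam

cb = PyDSCallback(1)  # step_size = 1: the step hooks are eligible on every step
# each set_* call stores the Python function and flips the matching *_needed_ flag
cb.set_epoch_end(lambda cb_param: print("epoch", cb_param.cur_epoch_num,
                                        "total steps", cb_param.cur_step_num))

# CallbackParam(epoch_num, cur_epoch_step, total_step_num) is constructible from Python too:
p = CallbackParam(1, 10, 10)
assert p.cur_step_num_in_epoch == 10 and p.cur_step_num == 10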
@@ -20,6 +20,7 @@
 #include <map>
 
 #include "utils/ms_utils.h"
+#include "minddata/dataset/callback/py_ds_callback.h"
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/cache/cache_client.h"
 #include "minddata/dataset/engine/dataset_iterator.h"
@@ -738,8 +739,13 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
         (void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
       } else if (key == "cache") {
         cache_client = value.cast<std::shared_ptr<CacheClient>>();
+      } else if (key == "callbacks") {
+        std::vector<std::shared_ptr<DSCallback>> callbacks;
+        std::transform(value.begin(), value.end(), std::back_inserter(callbacks),
+                       [](py::handle cb) { return cb.cast<std::shared_ptr<PyDSCallback>>(); });
+        (void)map_builder.AddCallbacks(callbacks);
       } else {
-        RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
+        RETURN_STATUS_UNEXPECTED("Error in parsing MapOp: Unhandled key: " + key);
       }
     }
   }
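On the Python side, the "callbacks" entry handled here is filled in by Dataset.map (see the datasets.py hunk at the end of this diff). A hedged sketch of the user-level call that reaches this branch, where data and some_op stand for any existing dataset and map operation:

# Illustrative only: each callback is converted to a PyDSCallback and cast back here.
from mindspore.dataset.callback import DSCallback

class PrintEpoch(DSCallback):
    def ds_epoch_begin(self, ds_run_context):
        print("starting epoch", ds_run_context.cur_epoch_num)

data = data.map(operations=[some_op], callbacks=PrintEpoch())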
@@ -0,0 +1,14 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)


if (ENABLE_PYTHON)
    add_library(callback OBJECT
        callback_manager.cc
        py_ds_callback.cc
        )
else ()
    add_library(callback OBJECT
        callback_manager.cc
        )
endif ()
@@ -0,0 +1,160 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"

namespace mindspore {
namespace dataset {

void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
  callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
}

Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
  RETURN_UNEXPECTED_IF_NULL(op);
  op_ = op;
  // turn the flag on if callback is set
  enabled_ = !callbacks_.empty();

  // error check for each of the callbacks
  for (auto &cb : callbacks_) {
    CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
  }

  return Status::OK();
}

Status CallbackManager::Begin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no begin is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no epoch_begin is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
      callback_inds.push_back(ind);
  }
  // return Status::OK() if no step_begin is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::End(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no end is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no epoch_end is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
      callback_inds.push_back(ind);
  }
  // return Status::OK() if no step_end is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
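The N-step gating in StepBegin/StepEnd above fires on epoch-steps 1, 1 + step_size, 1 + 2 * step_size, and so on, because cur_epoch_step_num is 1-based. A small Python mirror of the predicate, for illustration only:

def step_hook_fires(cur_epoch_step_num, step_size):
    # mirrors the C++ condition: (cur_epoch_step_num_ - 1) % step_size == 0
    return (cur_epoch_step_num - 1) % step_size == 0

# with step_size = 3, the hook runs on epoch-steps 1, 4, 7, ...
assert [s for s in range(1, 10) if step_hook_fires(s, 3)] == [1, 4, 7]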
@@ -0,0 +1,79 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H

#include <memory>
#include <vector>

#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

// forward declare to avoid cyclic include of dataset_op.h
class DatasetOp;

/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
class CallbackManager {
 public:
  /// CallbackManager default constructor. Init needs to be called before using the created instance.
  CallbackManager() : enabled_(false) {}

  /// \brief Add a list of callbacks to be managed by this DatasetOp
  /// \param[in] callbacks list of callbacks to perform
  void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);

  /// \brief DatasetOp needs to call Init if it wishes to use callbacks; Init turns enabled_ on when callbacks are set
  /// \param[in] op this pointer is used by the CallbackManager to pause worker threads
  /// \return Status
  Status Init(std::shared_ptr<DatasetOp> op);

  /// \brief callback function called at the start of the first row
  /// \return Status
  Status Begin(const CallbackParam &);

  /// \brief callback function called at the start of each epoch
  /// \return Status
  Status EpochBegin(const CallbackParam &);

  /// \brief callback function called at the start of each row
  /// \return Status
  Status StepBegin(const CallbackParam &);

  /// \brief callback function called after the last row is processed
  /// \return Status
  Status End(const CallbackParam &);

  /// \brief callback function called at the end of each epoch
  /// \return Status
  Status EpochEnd(const CallbackParam &);

  /// \brief callback function called at the end of each row
  /// \return Status
  Status StepEnd(const CallbackParam &);

 private:
  bool enabled_;                   // flag to enable callback; if false, all functions return immediately
  std::shared_ptr<DatasetOp> op_;  // back pointer to DatasetOp; each DatasetOp has only 1 CallbackManager
  std::vector<std::shared_ptr<DSCallback>> callbacks_;  // list of callbacks the DatasetOp needs to call
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H

#include <nlohmann/json.hpp>

namespace mindspore {
namespace dataset {

/// CallbackParam is the object a DatasetOp uses to pass run-time information to a user-defined function.
/// This is a prototype for now; more fields will be added.
class CallbackParam {
 public:
  CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
      : cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {}

  // these are constant public fields for easy access and consistency with python cb_param
  // the names and orders are consistent with batchInfo
  const int64_t cur_epoch_num_;       // current epoch
  const int64_t cur_epoch_step_num_;  // step number of the current epoch
  const int64_t cur_step_num_;        // step number since the first row
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
@@ -0,0 +1,100 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H

#include <memory>
#include <utility>
#include <vector>

#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

class DSCallback {
 public:
  /// \brief constructor of DSCallback; this is the base class for all front-end-specific callbacks
  /// \param step_size number of steps between calls to DSNStepBegin()
  explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}

  /// \brief actual callback function for begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSBegin(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for step_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSEnd(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for epoch_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for step_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;

  /// \brief predicate function, whether begin callback is needed
  /// \return bool
  virtual bool IsBeginNeeded() = 0;

  /// \brief predicate function, whether epoch_begin callback is needed
  /// \return bool
  virtual bool IsEpochBeginNeeded() = 0;

  /// \brief predicate function, whether step_begin callback is needed
  /// \return bool
  virtual bool IsNStepBeginNeeded() = 0;

  /// \brief predicate function, whether end callback is needed
  /// \return bool
  virtual bool IsEndNeeded() = 0;

  /// \brief predicate function, whether epoch_end callback is needed
  /// \return bool
  virtual bool IsEpochEndNeeded() = 0;

  /// \brief predicate function, whether step_end callback is needed
  /// \return bool
  virtual bool IsNStepEndNeeded() = 0;

  /// \brief getter
  /// \return step_size
  int32_t step_size() const { return step_size_; }

 protected:
  int32_t step_size_;  // step begin/end will be called every step_size_
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
@@ -0,0 +1,86 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

Status PyDSCallback::DSBegin(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(begin_func_, cb_param);
}
Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param);
}
Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param);
}
Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); }

Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param);
}
Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param);
}

bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
bool PyDSCallback::IsEndNeeded() { return end_needed_; }

Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
  {
    // Acquire Python GIL
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    f(cb_param);
  }
  return Status::OK();
}
void PyDSCallback::setBegin(py::function f) {
  begin_func_ = f;
  begin_needed_ = true;
}
void PyDSCallback::setEnd(py::function f) {
  end_func_ = f;
  end_needed_ = true;
}
void PyDSCallback::setEpochBegin(py::function f) {
  epoch_begin_func_ = f;
  epoch_begin_needed_ = true;
}
void PyDSCallback::setEpochEnd(py::function f) {
  epoch_end_func_ = f;
  epoch_end_needed_ = true;
}
void PyDSCallback::setStepBegin(py::function f) {
  step_begin_func_ = f;
  step_begin_needed_ = true;
}
void PyDSCallback::setStepEnd(py::function f) {
  step_end_func_ = f;
  step_end_needed_ = true;
}

}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,130 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H

#include <memory>
#include <utility>
#include <vector>

#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/pybind11.h"

namespace mindspore {
namespace dataset {

namespace py = pybind11;

class PyDSCallback : public DSCallback {
 public:
  /// \brief constructor for PyDSCallback. This callback is for the Python front end
  explicit PyDSCallback(int32_t step_size = 1)
      : DSCallback(step_size),
        begin_needed_(false),
        epoch_begin_needed_(false),
        step_begin_needed_(false),
        end_needed_(false),
        epoch_end_needed_(false),
        step_end_needed_(false) {}

  void setBegin(py::function f);
  void setEnd(py::function f);
  void setEpochBegin(py::function f);
  void setEpochEnd(py::function f);
  void setStepBegin(py::function f);
  void setStepEnd(py::function f);

  /// \brief actual callback function for begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSBegin(const CallbackParam &cb_param) override;

  /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSEpochBegin(const CallbackParam &cb_param) override;

  /// \brief actual callback function for step_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSNStepBegin(const CallbackParam &cb_param) override;

  /// \brief actual callback function for end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSEnd(const CallbackParam &cb_param) override;

  /// \brief actual callback function for epoch_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSEpochEnd(const CallbackParam &cb_param) override;

  /// \brief actual callback function for step_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSNStepEnd(const CallbackParam &cb_param) override;

  /// \brief predicate function, whether begin callback is needed
  /// \return bool
  bool IsBeginNeeded() override;

  /// \brief predicate function, whether epoch_begin callback is needed
  /// \return bool
  bool IsEpochBeginNeeded() override;

  /// \brief predicate function, whether step_begin callback is needed
  /// \return bool
  bool IsNStepBeginNeeded() override;

  /// \brief predicate function, whether end callback is needed
  /// \return bool
  bool IsEndNeeded() override;

  /// \brief predicate function, whether epoch_end callback is needed
  /// \return bool
  bool IsEpochEndNeeded() override;

  /// \brief predicate function, whether step_end callback is needed
  /// \return bool
  bool IsNStepEndNeeded() override;

  /// \brief helper function to acquire the GIL and then execute a pyfunc
  /// \param f the python function
  /// \param cb_param
  /// \return Status
  static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);

 private:
  py::function begin_func_;
  py::function epoch_begin_func_;
  py::function step_begin_func_;
  py::function end_func_;
  py::function epoch_end_func_;
  py::function step_end_func_;

  bool begin_needed_;
  bool epoch_begin_needed_;
  bool step_begin_needed_;
  bool end_needed_;
  bool epoch_end_needed_;
  bool step_end_needed_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
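On the Python side (ds_callback.py later in this diff), only the methods a user actually overrides are wired into these set_* hooks, so the *_needed_ flags stay false for unused hooks. The override check compares class attributes; a standalone sketch of the idiom:

# Illustrative only: this is how create_runtime_obj detects overridden methods.
class Base:
    def hook(self, ctx):
        pass

class Child(Base):
    def hook(self, ctx):
        print("overridden")

print(Child.hook != Base.hook)  # True: Child replaced the method, so wire it up
print(Base.hook != Base.hook)   # False: not overridden, leave the hook disabled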
@@ -21,6 +21,8 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+
+#include "minddata/dataset/callback/callback_manager.h"
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/util/status.h"
@@ -358,6 +360,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return boolean returns true if it's last iteration
   bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
 
+  /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp.
+  /// When invoked, it blocks until all the workers have finished their remaining work and gone to sleep.
+  /// Since all ParallelOps use a QueueList to sync with the master, workers automatically wait on the QueueList
+  /// when they are done; hence, for now, an Unpause() function is not needed. Only ParallelOp needs to override
+  /// this function.
+  /// \return Status
+  virtual Status PauseFromMaster() { return Status::OK(); }
+
  protected:
   /// \brief Removes a parent operator from this operator
   /// \notes External callers do not have access to this function
@@ -394,6 +404,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   std::unique_ptr<DbConnector> out_connector_;                   // Output Connector
   std::unordered_map<std::string, int32_t> column_name_id_map_;  // Mapping between col index and col name
   std::mutex column_name_map_mutex_;                             // For protecting shared access to the column map
+  CallbackManager callback_manager_;                             // Manages callbacks associated with a DatasetOp
 
  private:
   /// Sets the operator id.
@@ -15,25 +15,23 @@
  */
 #include <algorithm>
-#include <cstring>
-#include <iomanip>
-#include <iostream>
 #include <memory>
 #include <vector>
 #include "minddata/dataset/core/config_manager.h"
 
+#include "minddata/dataset/callback/callback_param.h"
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/core/global_context.h"
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/data_buffer.h"
 #include "minddata/dataset/engine/db_connector.h"
-#include "minddata/dataset/engine/execution_tree.h"
-#include "minddata/dataset/engine/opt/pass.h"
-#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
 #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
 #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
+#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/kernels/tensor_op.h"
-#include "utils/log_adapter.h"
 #include "minddata/dataset/util/task_manager.h"
+#include "utils/log_adapter.h"
 
 namespace mindspore {
 namespace dataset {
@@ -58,6 +56,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
   RETURN_IF_NOT_OK(sanityCheck());
   *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
                                  std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
+  (*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_));
   return Status::OK();
 }
 
@@ -164,7 +163,10 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)
 Status MapOp::operator()() {
   // Create and register the local queues.
   local_queues_.Init(num_workers_, oc_queue_size_);
+  // init callback
+  RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
   Status rc = local_queues_.Register(tree_->AllTasks());
+  RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
   if (rc.IsError()) {
     TaskManager::FindMe()->Post();
     return rc;
@@ -175,28 +177,51 @@ Status MapOp::operator()() {
   // Synchronize with TaskManager
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(rc);
+  // num_buffers received, including eoe, num_epoch, num_step of current epoch
+  int64_t num_buf = 0, ep_step = 0, total_step = 0;
+  RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
 
-  int64_t que_id = 0;
   std::unique_ptr<DataBuffer> buff;
-  bool is_eof = false;
-  // Drain output connector of the previous op, generate jobs for worker threads, and distribute them via local queues
-  // Stop when all worker threads are finished (received EOF)
-  while (!is_eof) {
-    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
-    is_eof = buff->eof();
-
-    // Create an empty map worker job to be populated by a databuffer and map jobs
-    std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>();
-    worker_job->databuffer = std::move(buff);
-
-    // Populate map worker job for a worker to execute
-    RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
-
-    // Push map worker job to the corresponding worker's queue
-    RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(worker_job)));
-    que_id = (que_id + 1) % num_workers_;
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
+  while (!buff->eof()) {
+    if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
+      RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+    }
+    while (!buff->eoe()) {
+      ep_step++;
+      total_step++;
+      RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+      // Create an empty map worker job to be populated by a databuffer and map jobs
+      std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
+
+      // Populate map worker job for a worker to execute
+      RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
+
+      // Push map worker job to the corresponding worker's queue
+      RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
+      RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
+    }
+
+    // reset epoch_step when a new epoch is about to start
+    if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
+      RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+      ep_step = 0;
+    }
+    // send the eoe buffer to worker
+    std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
+    RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
+    UpdateRepeatAndEpochCounter();
+    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
   }
 
+  // the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback
+  // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
+  // handle eof logic
+  std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
+  RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
   return Status::OK();
 }
 
@@ -213,25 +238,19 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
   // Fetch next data buffer and map job list
   RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
 
-  // Sanity check the databuffer.
-  // Special case: if there's more threads than buffers, some threads simply get the final control
-  // messages (eoe/eof), and so they will not perform the check.
-  if (!in_buffer->eoe() && !in_buffer->eof()) {
-    int32_t num_rows = in_buffer->NumRows();
-    int32_t num_cols = in_buffer->NumCols();
-    if (num_rows == 0 || num_cols == 0) {
-      RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer.");
-    }
-  }
-
   // Now that init work is done, drop into the main fetching loop.
   // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
   // rather than use the base-class defaults.
   while (true) {
-    // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
-    // with Performance Mode design.
-    if (in_buffer->eoe()) {
-      UpdateRepeatAndEpochCounter();
+    // handle the pause logic. Pause is triggered when a buffer id of -1 with no special flag and no row is received
+    if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
+      // when a worker receives the signal from the master thread, it increments an atomic int;
+      // the last one to increment the counter wakes up the master thread
+      if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
+      // this will block the worker until the master thread gives it new work
+      RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
+      continue;
+    } else if (in_buffer->eoe()) {
       // Calling base class EoeReceived to forward eoe buffer.
       RETURN_IF_NOT_OK(EoeReceived(worker_id));
       // Fetch next data buffer and map job list
@@ -243,6 +262,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
       break;
     }
 
+    CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
     std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
     // Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
    RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list));
@@ -281,9 +301,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
   std::vector<TensorRow> result_table;
   // Executing the list of jobs
   for (size_t i = 0; i < job_list.size(); i++) {
-    // Executre MapJob.
+    // Execute MapJob.
     RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
-    // Assign the pocessed data as an input for the next job processing, except for the last TensorOp in the list.
+    // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
     if (i + 1 < job_list.size()) {
       job_input_table = std::move(result_table);
     }
@@ -428,5 +448,20 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
   // Downcast shared pointer then call visitor
   return p->RunOnNode(shared_from_base<MapOp>(), modified);
 }
+
+Status MapOp::PauseFromMaster() {
+  // reset num_paused workers to 0
+  num_workers_paused_ = 0;
+  for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
+    // a special buffer (id=-1, empty, none flag) is used to signal that the worker needs to pause.
+    RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
+      std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
+  }
+  // wait until all workers are done processing their work in local_queue_
+  RETURN_IF_NOT_OK(master_pause_wp_.Wait());
+  // clear the WaitPost for the next Wait()
+  master_pause_wp_.Clear();
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
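The pause protocol implemented by PauseFromMaster and WorkerEntry is a counted rendezvous: the master pushes one sentinel job per worker, each worker bumps an atomic counter when it sees its sentinel, the last worker wakes the master, and every worker then blocks on its queue until new work arrives. A rough Python analogue, illustrative only (the real code uses QueueList, std::atomic_int, and WaitPost rather than a Barrier):

import queue
import threading

NUM_WORKERS = 4
SENTINEL = object()  # stands in for the id = -1 empty DataBuffer used above
queues = [queue.Queue() for _ in range(NUM_WORKERS)]
rendezvous = threading.Barrier(NUM_WORKERS + 1)  # all workers plus the master

def worker(q):
    while True:
        job = q.get()
        if job is SENTINEL:
            rendezvous.wait()  # the last arriver releases everyone, including the master
            continue           # then block on q.get() until the master sends new work
        # ... process a real job here ...

for q in queues:
    threading.Thread(target=worker, args=(q,), daemon=True).start()

def pause_from_master():
    for q in queues:
        q.put(SENTINEL)    # one pause token per worker, as in PauseFromMaster
    rendezvous.wait()      # returns only after every worker has checked in

pause_from_master()        # safe point: the master can now run callbacks
print("all workers paused")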
@@ -16,15 +16,19 @@
 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
 
+#include <atomic>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <utility>
 #include <vector>
+
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
 #include "minddata/dataset/engine/datasetops/parallel_op.h"
 #include "minddata/dataset/kernels/tensor_op.h"
 #include "minddata/dataset/util/queue.h"
-#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
+#include "minddata/dataset/util/wait_post.h"
 
 namespace mindspore {
 namespace dataset {
@@ -108,6 +112,13 @@ class MapOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &AddCallbacks(const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
+      builder_callbacks_.insert(builder_callbacks_.end(), callbacks.begin(), callbacks.end());
+      return *this;
+    }
+
     // The builder "build" method creates the final object.
     // @param ptr The shared_ptr to the new MapOp object
     // @return Status
@@ -116,6 +127,7 @@ class MapOp : public ParallelOp {
    private:
     std::vector<std::string> build_in_col_names_;
    std::vector<std::string> build_out_col_names_;
+    std::vector<std::shared_ptr<DSCallback>> builder_callbacks_;
     std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
     int32_t build_num_workers_;
     int32_t build_op_connector_size_;
@@ -186,6 +198,7 @@ class MapOp : public ParallelOp {
   // A unit of job for map worker thread.
   // MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
   struct MapWorkerJob {
+    explicit MapWorkerJob(std::unique_ptr<DataBuffer> db) : databuffer(std::move(db)) {}
     std::vector<std::shared_ptr<MapJob>> jobs;
     std::unique_ptr<DataBuffer> databuffer;
   };
@@ -215,6 +228,12 @@ class MapOp : public ParallelOp {
   // Indices of the columns to process.
   std::vector<size_t> to_process_indices_;
 
+  // wait post used to perform the pausing logic in MapOp
+  WaitPost master_pause_wp_;
+
+  // count number of workers that have signaled master
+  std::atomic_int num_workers_paused_;
+
   // Private function for worker/thread to loop continuously. It comprises the main
   // logic of MapOp: getting the data from previous Op, validating user specified column names,
   // applying a list of TensorOps to each of the data, process the results and then
@@ -247,6 +266,13 @@ class MapOp : public ParallelOp {
   // Private function for initializing private variables such as in_columns_, out_columns_.
   // @return - Status
   Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);
 
+  // This function should only be called from the master thread. It suspends the operation of all workers and
+  // has them wait on the QueueList. The master thread sends a token to each worker and then waits on a WaitPost.
+  // Upon receiving the suspension token, each worker increments an atomic count; the last worker to do the
+  // increment wakes up the master.
+  // @return - Status
+  Status PauseFromMaster() override;
 };
 }  // namespace dataset
 }  // namespace mindspore
@@ -34,7 +34,7 @@ class Semaphore {
   /// \brief Decrement the internal counter. Will be blocked if the value is 0.
   /// \return Error code. Can get interrupt.
   Status P();
-  /// \brief Increment the internal counter. Wakeup on of the watiers if any.
+  /// \brief Increment the internal counter. Wake up one of the waiters if any.
   void V();
   /// \brief Peek the internal value
   /// \return The internal value
@@ -59,6 +59,13 @@ namespace dataset {
   }                                   \
   } while (false)
 
+#define RETURN_OK_IF_TRUE(_condition) \
+  do {                                \
+    if (_condition) {                 \
+      return Status::OK();            \
+    }                                 \
+  } while (false)
+
 enum class StatusCode : char {
   kOK = 0,
   kOutOfMemory = 1,
@@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""init file for python callback"""
from .ds_callback import DSCallback, WaitedDSCallback

__all__ = ["DSCallback", "WaitedDSCallback"]
@ -0,0 +1,232 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Python callback class
|
||||
"""
|
||||
import threading
|
||||
from mindspore._c_dataengine import PyDSCallback
|
||||
from mindspore.train.callback import Callback
|
||||
from .validators import check_callback
|
||||
|
||||
|
||||
class DSCallback:
|
||||
"""
|
||||
Abstract base class used to build a dataset callback class.
|
||||
|
||||
Args:
|
||||
step_size (int, optional): The number of steps before the step_begin and step_end are called (Default=1).
|
||||
|
||||
Examples:
|
||||
>>> class PrintInfo(DSCallback):
|
||||
>>> def ds_epoch_end(self, ds_run_context):
|
||||
>>> print(cb_params.cur_epoch_num)
|
||||
>>> print(cb_params.cur_step_num)
|
||||
>>>
|
||||
>>> data = data.map(operations=op, callbacks=PrintInfo())
|
||||
"""
|
||||
|
||||
@check_callback
|
||||
def __init__(self, step_size=1):
|
||||
self.step_size = step_size
|
||||
|
||||
def ds_begin(self, ds_run_context):
|
||||
"""
|
||||
Called before the data pipeline is started.
|
||||
|
||||
Args:
|
||||
ds_run_context (RunContext): Include some information of the pipeline.
|
||||
"""
|
||||
|
||||
def ds_epoch_begin(self, ds_run_context):
|
||||
"""
|
||||
Called before a new epoch is started.
|
||||
|
||||
Args:
|
||||
ds_run_context (RunContext): Include some information of the pipeline.
|
||||
"""
|
||||
|
||||
def ds_epoch_end(self, ds_run_context):
|
||||
"""
|
||||
Called after an epoch is finished.
|
||||
|
||||
Args:
|
||||
ds_run_context (RunContext): Include some information of the pipeline.
|
||||
"""
|
||||
|
||||
def ds_step_begin(self, ds_run_context):
|
||||
"""
|
||||
Called before n steps are started.
|
||||
|
||||
Args:
|
||||
ds_run_context (RunContext): Include some information of the pipeline.
|
||||
"""
|
||||
|
||||
def ds_step_end(self, ds_run_context):
|
||||
"""
|
||||
Called after n steps are finished.
|
||||
|
||||
Args:
|
||||
ds_run_context (RunContext): Include some information of the pipeline.
|
||||
"""
|
||||
|
||||
def create_runtime_obj(self):
|
||||
"""
|
||||
Creates a runtime (C++) object from the callback methods defined by the user.
|
||||
|
||||
Returns: _c_dataengine.PyDSCallback
|
||||
"""
|
||||
c_cb = PyDSCallback(self.step_size)
|
||||
at_least_one = False
|
||||
|
||||
if self.__class__.ds_begin != DSCallback.ds_begin:
|
||||
c_cb.set_begin(self.ds_begin)
|
||||
at_least_one = True
|
||||
|
||||
if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
|
||||
c_cb.set_epoch_begin(self.ds_epoch_begin)
|
||||
at_least_one = True
|
||||
if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
|
||||
c_cb.set_epoch_end(self.ds_epoch_end)
|
||||
at_least_one = True
|
||||
|
||||
if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
|
||||
c_cb.set_step_begin(self.ds_step_begin)
|
||||
at_least_one = True
|
||||
if self.__class__.ds_step_end != DSCallback.ds_step_end:
|
||||
c_cb.set_step_end(self.ds_step_end)
|
||||
at_least_one = True
|
||||
|
||||
if not at_least_one:
|
||||
raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
|
||||
|
||||
return c_cb
|
||||
|
||||
|
||||
class WaitedDSCallback(Callback, DSCallback):
|
||||
"""
|
||||
Abstract base class used to build a dataset callback class that are synchronized with the training callback.
|
||||
|
||||
This class can be used to execute a user defined logic right after the previous step or epoch.
|
||||
For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
|
||||
|
||||
Examples:
|
||||
>>> my_cb = MyWaitedCallback(32)
|
||||
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
|
||||
>>> data = data.batch(32)
|
||||
>>> # define the model
|
||||
>>> model.train(epochs, data, callbacks=[my_cb])
|
    Args:
        step_size: the number of rows in each step.
            Usually the step size will be equal to the batch size (default=1).
    """

    def __init__(self, step_size=1):
        super().__init__()
        self.step_size = step_size
        self.step_event = threading.Event()
        self.step_run_context = None

        self.epoch_event = threading.Event()
        self.epoch_run_context = None

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset epoch starts and after the previous training epoch has ended.

        Args:
            train_run_context: Information about the model, with feedback from the previous epoch.
            ds_run_context: Information about the dataset pipeline.
        """

    def sync_step_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset step starts and after the previous training step has ended.

        Args:
            train_run_context: Information about the model, with feedback from the previous step.
            ds_run_context: Information about the dataset pipeline.
        """

    def epoch_end(self, run_context):
        """
        Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.

        Args:
            run_context: Information about the model.
        """
        self.epoch_run_context = run_context
        self.epoch_event.set()
        self.epoch_event.clear()

    def ds_epoch_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.

        Args:
            ds_run_context: Information about the pipeline.
        """
        if ds_run_context.cur_epoch_num > 1:
            if self.epoch_run_context is None:
                self.epoch_event.wait()
            self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
            self.epoch_run_context = None

    def step_end(self, run_context):
        """
        Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.

        Args:
            run_context: Information about the model.
        """
        self.step_run_context = run_context
        self.step_event.set()
        self.step_event.clear()

    def ds_step_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.

        Args:
            ds_run_context: Information about the pipeline.
        """
        if ds_run_context.cur_step_num > self.step_size:
            if self.step_run_context is None:
                self.step_event.wait()
            self.sync_step_begin(self.step_run_context, ds_run_context)
            self.step_run_context = None

    def create_runtime_obj(self):
        """
        Internal method; creates a runtime (C++) object from the callback methods defined by the user.

        Returns: _c_dataengine.PyDSCallback
        """
        c_cb = PyDSCallback(self.step_size)
        at_least_one = False

        if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
            c_cb.set_step_begin(self.ds_step_begin)
            at_least_one = True

        if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
            c_cb.set_epoch_begin(self.ds_epoch_begin)
            at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")

        return c_cb
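For orientation (not part of this change set), a minimal sketch of how a user is expected to subclass WaitedDSCallback, overriding only the sync_* hooks; the class name below is hypothetical:

from mindspore.dataset.callback import WaitedDSCallback

class EpochPrinter(WaitedDSCallback):  # hypothetical example class
    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # runs only after the matching training epoch_end has released the wait
        print("dataset epoch", ds_run_context.cur_epoch_num)

    def sync_step_begin(self, train_run_context, ds_run_context):
        print("dataset step", ds_run_context.cur_step_num)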
@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Built-in validators.
"""

from functools import wraps

from ..core.validator_helpers import parse_user_args, check_pos_int32


def check_callback(method):
    """Check the input arguments of DSCallback."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [step_size], _ = parse_user_args(method, *args, **kwargs)
        check_pos_int32(step_size, "step_size")
        return method(self, *args, **kwargs)

    return new_method
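As a quick illustration of what this validator enforces (a sketch, assuming check_callback decorates DSCallback.__init__): step_size must be a positive int32, so a zero or negative value is rejected before the constructor body runs.

from mindspore.dataset.callback import DSCallback

DSCallback(step_size=1)  # accepted: positive int32
DSCallback(step_size=0)  # rejected by check_pos_int32 with a ValueError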
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-   check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\
+   check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
    check_paddeddataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
@@ -395,7 +395,7 @@ class Dataset:

    @check_map
    def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
-           num_parallel_workers=None, python_multiprocessing=False, cache=None):
+           num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
        """
        Apply each operation in operations to this dataset.
@@ -438,6 +438,8 @@ class Dataset:
                option could be beneficial if the python operation is computational heavy (default=False).
            cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
                The cache feature is under development and is not recommended.
            callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (default=None).

        Returns:
            MapDataset, dataset after mapping operation.
@@ -552,7 +554,7 @@ class Dataset:
        >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
        """
        return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
-                         python_multiprocessing, cache)
+                         python_multiprocessing, cache, callbacks)

    @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=1):
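Taken together with the validator changes below, the new parameter lets a callback ride along with a map operation. A minimal sketch (the callback class is hypothetical; the map call mirrors the tests further down):

import mindspore.dataset as ds
from mindspore.dataset.callback import DSCallback

class StepLogger(DSCallback):  # hypothetical example class
    def ds_step_begin(self, ds_run_context):
        print("step", ds_run_context.cur_step_num)

data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data = data.map(operations=(lambda x: x), callbacks=StepLogger())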
@@ -1548,6 +1550,7 @@ class DatasetOp(Dataset):
            return self.children[0].get_class_indexing()
        raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))


class BucketBatchByLengthDataset(DatasetOp):
    """
    The result of applying BucketBatchByLength operator to the input dataset.
@@ -1964,14 +1967,14 @@ class MapDataset(DatasetOp):
            option could be beneficial if the python operation is computational heavy (default=False).
        cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
            The cache feature is under development and is not recommended.
        callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (default=None).

    Raises:
        ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
    """

    def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
-                num_parallel_workers=None, python_multiprocessing=False, cache=None):
+                num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
        super().__init__(num_parallel_workers)
        self.children.append(input_dataset)
        if input_columns is not None and not isinstance(input_columns, list):
@@ -1996,6 +1999,11 @@ class MapDataset(DatasetOp):
        self.python_multiprocessing = python_multiprocessing
        self.process_pool = None

        if callbacks is not None and not isinstance(callbacks, list):
            callbacks = [callbacks]

        self.callbacks = callbacks

    def get_args(self):
        args = super().get_args()
        args["input_columns"] = self.input_columns
@@ -2003,6 +2011,9 @@ class MapDataset(DatasetOp):
        args["output_columns"] = self.output_columns
        args["columns_order"] = self.columns_order
        args["cache"] = self.cache.cache_client if self.cache is not None else None

        if self.callbacks is not None:
            args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks]
        return args

    def get_dataset_size(self):
@@ -2034,6 +2045,7 @@ class MapDataset(DatasetOp):
        new_op.cache = copy.deepcopy(self.cache, memodict)
        new_op.operations = self.operations
        new_op.dataset_size = self.dataset_size
        new_op.callbacks = self.callbacks
        return new_op

    # Iterator bootstrap will be called on iterator construction.
@@ -2393,7 +2405,6 @@ class ConcatDataset(DatasetOp):
                self._children_start_end_index_[index][0] = cumulative_samples_nums
                self._children_start_end_index_[index][1] = tem_value % sampler.num_shards

                tem_sampler = copy.deepcopy(sampler)
                tem_sampler.set_offset(cumulative_samples_nums)
                child.sampler = tem_sampler
@@ -2556,7 +2567,7 @@ class RangeDataset(MappableDataset):

    def get_dataset_size(self):
        if self.dataset_size is None:
-           self.dataset_size = math.ceil((self.stop - self.start)/self.step)
+           self.dataset_size = math.ceil((self.stop - self.start) / self.step)
        return self.dataset_size

@@ -3423,7 +3434,7 @@ class GeneratorDataset(MappableDataset):
            if not self.num_shards:
                self.dataset_size = len(self.source)
            else:
-               self.dataset_size = math.ceil(len(self.source)/self.num_shards)
+               self.dataset_size = math.ceil(len(self.source) / self.num_shards)

            rows_from_sampler = self._get_sampler_dataset_size()
            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
@@ -5428,6 +5439,7 @@ class NumpySlicesDataset(GeneratorDataset):
                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
                         num_shards=num_shards, shard_id=shard_id)


class _PaddedDataset:
    """
    Mainly for combining false samples provided by users into a dataset.
@@ -5435,6 +5447,7 @@ class _PaddedDataset:
    Args:
        padded_samples (list(dict)): the data provided by the user to be added to the initial Dataset.
    """

    def __init__(self, padded_samples):
        self.column_names = list(padded_samples[0].keys())
        self.padded_samples = padded_samples
@@ -5445,6 +5458,7 @@ class _PaddedDataset:
    def __len__(self):
        return len(self.padded_samples)


class PaddedDataset(GeneratorDataset):
    """
    Create a dataset with fake data provided by the user. Mainly used to add to the original dataset
@@ -5463,6 +5477,7 @@ class PaddedDataset(GeneratorDataset):
        >>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
        >>> ds1 = ds.PaddedDataset(data1)
    """

    @check_paddeddataset
    def __init__(self, padded_samples):
        dataset = _PaddedDataset(padded_samples)
@@ -23,6 +23,7 @@ from functools import wraps

import numpy as np
from mindspore._c_expression import typing
from mindspore.dataset.callback import DSCallback
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
    INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
    validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
@@ -31,6 +32,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
from . import datasets
from . import samplers
from . import cache_client
from .. import callback


def check_imagefolderdatasetv2(method):
@@ -247,6 +249,7 @@ def check_celebadataset(method):

    return new_method


def check_save(method):
    """A wrapper that wraps a parameter checker to the save op."""
|
|||
nreq_param_int = ['num_files']
|
||||
nreq_param_str = ['file_name', 'file_type']
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
|
||||
if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
|
||||
raise ValueError("num_files should between {} and {}.".format(1, 1000))
|
||||
validate_dataset_param_value(nreq_param_str, param_dict, str)
|
||||
if param_dict.get('file_type') != 'mindrecord':
|
||||
|
@@ -265,6 +268,8 @@ def check_save(method):
        return method(self, *args, **kwargs)

    return new_method


def check_minddataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
@@ -362,6 +367,7 @@ def check_generatordataset(method):

    return new_method


def check_random_dataset(method):
    """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
@@ -545,7 +551,8 @@ def check_map(method):

    @wraps(method)
    def new_method(self, *args, **kwargs):
-       [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \
+       [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache,
+        callbacks], _ = \
            parse_user_args(method, *args, **kwargs)

        nreq_param_columns = ['input_columns', 'output_columns']
@@ -558,9 +565,17 @@ def check_map(method):
        if cache is not None:
            type_check(cache, (cache_client.DatasetCache,), "cache")

        if callbacks is not None:
            if isinstance(callbacks, (list, tuple)):
                type_check_list(callbacks, (callback.DSCallback,), "callbacks")
            else:
                type_check(callbacks, (callback.DSCallback,), "callbacks")

        for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
            if param is not None:
                check_columns(param, param_name)
        if callbacks is not None:
            type_check(callbacks, (list, DSCallback), "callbacks")

        return method(self, *args, **kwargs)
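The net effect of these checks, sketched below (this mirrors test_callbacks_validations further down): anything that is not a DSCallback, or a list of them, is rejected at map() time.

import mindspore.dataset as ds

data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
data.map(operations=(lambda x: x), callbacks=0)  # rejected by type_check: callbacks must be a DSCallback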
@@ -15,6 +15,7 @@ SET(DE_UT_SRCS
    bounding_box_augment_op_test.cc
    arena_test.cc
    btree_test.cc
    callback_test.cc
    center_crop_op_test.cc
    channel_swap_test.cc
    circular_pool_test.cc
@@ -0,0 +1,301 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <list>

#include "common/common.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/kernels/data/no_op.h"
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::MsLogLevel::INFO;

namespace mindspore {
namespace dataset {
namespace test {

std::shared_ptr<ExecutionTree> BuildTree(std::vector<std::shared_ptr<DatasetOp>> ops) {
  std::shared_ptr<ExecutionTree> tree = std::make_shared<ExecutionTree>();
  Status rc;
  for (size_t i = 0; i < ops.size(); i++) {
    rc = tree->AssociateNode(ops[i]);
    EXPECT_TRUE(rc.IsOk());
    if (i > 0) {
      rc = ops[i]->AddChild(ops[i - 1]);
      EXPECT_TRUE(rc.IsOk());
    }
    if (i == ops.size() - 1) {
      rc = tree->AssignRoot(ops[i]);
      EXPECT_TRUE(rc.IsOk());
    }
  }
  return tree;
}

class TestCallback : public DSCallback {
 public:
  TestCallback(int32_t step_size)
      : DSCallback(step_size),
        begin_(true),
        epoch_begin_(true),
        step_begin_(true),
        end_(true),
        epoch_end_(true),
        step_end_(true) {
    all_names_.reserve(32);
    all_step_nums_.reserve(32);
    all_ep_nums_.reserve(32);
  }

  Status DSBegin(const CallbackParam &cb_param) override {
    all_names_.push_back("BGN");
    all_step_nums_.push_back(cb_param.cur_step_num_);
    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
    return Status::OK();
  }
  Status DSEpochBegin(const CallbackParam &cb_param) override {
    all_names_.push_back("EPBGN");
    all_step_nums_.push_back(cb_param.cur_step_num_);
    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
    return Status::OK();
  }
  Status DSNStepBegin(const CallbackParam &cb_param) override {
    all_names_.push_back("SPBGN");
    all_step_nums_.push_back(cb_param.cur_step_num_);
    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
    return Status::OK();
  }
  Status DSEnd(const CallbackParam &cb_param) override {
    all_names_.push_back("END");
    all_step_nums_.push_back(cb_param.cur_step_num_);
    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
    return Status::OK();
  }
  Status DSEpochEnd(const CallbackParam &cb_param) override {
    all_names_.push_back("EPEND");
    all_step_nums_.push_back(cb_param.cur_step_num_);
    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
    return Status::OK();
  }
  Status DSNStepEnd(const CallbackParam &cb_param) override {
    all_names_.push_back("SPEND");
    all_step_nums_.push_back(cb_param.cur_step_num_);
    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
    return Status::OK();
  }

  bool IsBeginNeeded() override { return begin_; }
  bool IsEpochBeginNeeded() override { return epoch_begin_; }
  bool IsNStepBeginNeeded() override { return step_begin_; }
  bool IsEndNeeded() override { return end_; }
  bool IsEpochEndNeeded() override { return epoch_end_; }
  bool IsNStepEndNeeded() override { return step_end_; }

  std::vector<std::string> all_names(size_t len) {
    return std::vector<std::string>(all_names_.begin(), all_names_.begin() + len);
  }

  std::vector<int64_t> all_step_nums(size_t len) {
    return std::vector<int64_t>(all_step_nums_.begin(), all_step_nums_.begin() + len);
  }

  std::vector<int64_t> all_ep_nums(size_t len) {
    return std::vector<int64_t>(all_ep_nums_.begin(), all_ep_nums_.begin() + len);
  }

  // flags for turning each callback on and off
  bool begin_, epoch_begin_, step_begin_, end_, epoch_end_, step_end_;
  // names of the callback functions in invocation order: BGN, EPBGN, SPBGN, END, EPEND, SPEND
  std::vector<std::string> all_names_;
  std::vector<int64_t> all_step_nums_, all_ep_nums_;
};

}  // namespace test
}  // namespace dataset
}  // namespace mindspore

class MindDataTestCallback : public UT::DatasetOpTesting {
 public:
  void SetUp() override {
    DatasetOpTesting::SetUp();
    GlobalInit();
  }
};

TEST_F(MindDataTestCallback, TestBasicCallback) {
  // config callback
  Status rc;
  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
  std::shared_ptr<DSCallback> cb1 = tst_cb;
  tst_cb->end_ = false;  // don't do the end for now due to a timing issue
  // config leaf_op, use random_data to avoid I/O
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
  schema->AddColumn(col);
  std::shared_ptr<RandomDataOp> leaf;
  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf);
  EXPECT_TRUE(rc.IsOk());
  // config mapOp
  std::shared_ptr<MapOp> map_op;
  auto map_b = MapOp::Builder();
  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
  EXPECT_TRUE(rc.IsOk());
  // config RepeatOp
  std::shared_ptr<RepeatOp> repeat_op;
  rc = RepeatOp::Builder(2).Build(&repeat_op);
  // start build then launch tree
  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
  rc = tree->Prepare();
  EXPECT_TRUE(rc.IsOk());
  rc = tree->Launch();
  EXPECT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(tree);
  TensorMap tensor_map;
  rc = di.GetNextAsMap(&tensor_map);
  EXPECT_TRUE(rc.IsOk());
  while (!tensor_map.empty()) {
    rc = di.GetNextAsMap(&tensor_map);
    EXPECT_TRUE(rc.IsOk());
  }

  // with step_size = 64 and 44 rows repeated twice (88 rows), the step callbacks fire at
  // steps 1 and 65, and EPEND reports the last step, 88
  std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
  std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
  // compare only the first len entries to make sure no unexpected epoch_end or extra epoch_begin was recorded
  size_t len = 7;
  EXPECT_EQ(tst_cb->all_names(len), callback_names);
  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
}

TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
  // config callback
  Status rc;
  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
  std::shared_ptr<DSCallback> cb1 = tst_cb;
  tst_cb->end_ = false;  // don't do the end for now due to a timing issue
  // config leaf_op, use random_data to avoid I/O
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
  schema->AddColumn(col);
  std::shared_ptr<RandomDataOp> leaf;
  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
  EXPECT_TRUE(rc.IsOk());
  // config mapOp
  std::shared_ptr<MapOp> map_op;
  auto map_b = MapOp::Builder();
  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
  EXPECT_TRUE(rc.IsOk());
  // config RepeatOp
  std::shared_ptr<RepeatOp> repeat_op;
  rc = RepeatOp::Builder(2).Build(&repeat_op);
  // start build then launch tree
  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
  rc = tree->Prepare();
  EXPECT_TRUE(rc.IsOk());
  rc = tree->Launch();
  EXPECT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(tree);
  TensorMap tensor_map;
  size_t num_epochs = 2;
  for (size_t ep_num = 0; ep_num < num_epochs; ++ep_num) {
    rc = di.GetNextAsMap(&tensor_map);
    EXPECT_TRUE(rc.IsOk());

    while (tensor_map.size() != 0) {
      rc = di.GetNextAsMap(&tensor_map);
      EXPECT_TRUE(rc.IsOk());
    }
  }

  // 4 rows repeated twice gives 8 steps per epoch; with step_size = 4 the step callbacks fire at
  // steps 1 and 5 (then 9 and 13 in the second epoch), and EPEND reports steps 8 and 16
  std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND",
                                             "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};

  std::vector<int64_t> all_steps = {0, 0, 1, 1, 5, 5, 8, 8, 9, 9, 13, 13, 16};
  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2};

  size_t len = 13;
  EXPECT_EQ(tst_cb->all_names(len), callback_names);
  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
}

TEST_F(MindDataTestCallback, TestSelectedCallback) {
  // config callback
  Status rc;
  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
  std::shared_ptr<DSCallback> cb1 = tst_cb;
  tst_cb->end_ = false;
  // turn off the epoch callbacks
  tst_cb->epoch_begin_ = false;
  tst_cb->epoch_end_ = false;

  // config leaf_op, use random_data to avoid I/O
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
  schema->AddColumn(col);
  std::shared_ptr<RandomDataOp> leaf;
  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
  EXPECT_TRUE(rc.IsOk());
  // config mapOp
  std::shared_ptr<MapOp> map_op;
  auto map_b = MapOp::Builder();
  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
  EXPECT_TRUE(rc.IsOk());
  // config RepeatOp
  std::shared_ptr<RepeatOp> repeat_op;
  rc = RepeatOp::Builder(2).Build(&repeat_op);
  // start build then launch tree
  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
  rc = tree->Prepare();
  EXPECT_TRUE(rc.IsOk());
  rc = tree->Launch();
  EXPECT_TRUE(rc.IsOk());
  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(tree);
  TensorMap tensor_map;
  size_t num_epochs = 2;
  for (size_t ep_num = 0; ep_num < num_epochs; ++ep_num) {
    rc = di.GetNextAsMap(&tensor_map);
    EXPECT_TRUE(rc.IsOk());

    while (tensor_map.size() != 0) {
      rc = di.GetNextAsMap(&tensor_map);
      EXPECT_TRUE(rc.IsOk());
    }
  }

  // same pipeline as above, but with the epoch callbacks disabled only BGN and the step pairs remain
  std::vector<std::string> callback_names = {"BGN", "SPBGN", "SPEND", "SPBGN", "SPEND",
                                             "SPBGN", "SPEND", "SPBGN", "SPEND"};

  std::vector<int64_t> all_steps = {0, 1, 1, 5, 5, 9, 9, 13, 13};
  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 2, 2, 2, 2};

  size_t len = 9;
  EXPECT_EQ(tst_cb->all_names(len), callback_names);
  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
}
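Bridging to the Python tests that follow: both suites assert the same per-epoch ordering. A sketch of the expected sequence for one map op over e epochs and s steps with step_size=1, using the C++ tests' shorthand names:

def expected_order(epochs, steps):
    # BGN once, then per epoch: EPBGN, (SPBGN, SPEND) per step, EPEND
    seq = ["BGN"]
    for _ in range(epochs):
        seq += ["EPBGN"] + ["SPBGN", "SPEND"] * steps + ["EPEND"]
    return seq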
@@ -0,0 +1,365 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from builtins import range, super
import time

import pytest

from mindspore import context
from mindspore import log as logger
from mindspore.dataset.callback import DSCallback, WaitedDSCallback
from mindspore.train import Model
from mindspore.train.callback import Callback

import mindspore.dataset as ds
import mindspore.nn as nn

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class MyDSCallback(DSCallback):
    def __init__(self, step_size=1, events=None, cb_id=0):
        super().__init__(step_size)
        self.events = events
        self.cb_id = cb_id

    def append(self, event_name, ds_run_context):
        # event strings are formatted as name_epochNum_stepNumInEpoch_globalStepNum
        event = [event_name, ds_run_context.cur_epoch_num,
                 ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
        event = '_'.join([str(e) for e in event])
        index = -1
        for i, e in enumerate(self.events):
            if e[0] == event:
                index = i
                break
        if index != -1:
            self.events[index][1].append(self.cb_id)
        else:
            self.events.append((event, [self.cb_id]))

    def ds_begin(self, ds_run_context):
        self.append("begin", ds_run_context)

    def ds_end(self, ds_run_context):
        self.append("end", ds_run_context)

    def ds_epoch_begin(self, ds_run_context):
        self.append("epoch_begin", ds_run_context)

    def ds_epoch_end(self, ds_run_context):
        self.append("epoch_end", ds_run_context)

    def ds_step_begin(self, ds_run_context):
        self.append("step_begin", ds_run_context)

    def ds_step_end(self, ds_run_context):
        self.append("step_end", ds_run_context)


def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
    events = []
    cb_id = list(range(map_num))

    def append(name, e, s):
        event = [name, e + 1, s + 1, e * step_num * repeat + s + 1]
        event = '_'.join([str(ev) for ev in event])
        events.append((event, cb_id))

    events.append(("begin_0_0_0", cb_id))
    for e in range(epoch_num):
        append("epoch_begin", e, -1)
        for s in range(step_num * repeat):
            if s % step_size == 0:
                append("step_begin", e, s)
                append("step_end", e, s)
        append("epoch_end", e, step_num * repeat - 1)
    return events


def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
    events = []

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    my_cb = MyDSCallback(step_size=step_size, events=events)

    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    if repeat != 1:
        data = data.repeat(repeat)
    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
    assert expected_events == events


def build_test_case_2cbs(epochs, steps):
    events1 = []
    events2 = []
    my_cb1 = MyDSCallback(events=events1)
    my_cb2 = MyDSCallback(events=events2)

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])

    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps)
    assert expected_events == events1
    assert expected_events == events2


def build_test_case_2maps(epochs, steps):
    events = []
    my_cb1 = MyDSCallback(events=events, cb_id=0)
    my_cb2 = MyDSCallback(events=events, cb_id=1)

    arr = list(range(1, steps + 1))
    data = ds.NumpySlicesDataset(arr, shuffle=False)

    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
    data = data.map(operations=(lambda x: x), callbacks=my_cb2)

    itr = data.create_tuple_iterator(num_epochs=epochs)
    for _ in range(epochs):
        for _ in itr:
            pass

    expected_events = generate_expected(epochs, steps, map_num=2)

    # the two maps' begin events may interleave, so skip the first event here
    # and verify the per-event callback order separately below
    assert expected_events[1:] == events[1:]

    for event in events:
        assert len(event) == 2
        event, cb_ids = event
        if event != "begin_0_0_0":
            assert cb_ids[0] == 0
            assert cb_ids[1] == 1


def test_callbacks_all_methods():
    logger.info("test_callbacks_all_methods")

    build_test_case_1cb(1, 1)
    build_test_case_1cb(1, 2)
    build_test_case_1cb(1, 3)
    build_test_case_1cb(1, 4)

    build_test_case_1cb(2, 1)
    build_test_case_1cb(2, 2)
    build_test_case_1cb(2, 3)
    build_test_case_1cb(2, 4)

    build_test_case_1cb(3, 1)
    build_test_case_1cb(3, 2)
    build_test_case_1cb(3, 3)
    build_test_case_1cb(3, 4)


def test_callbacks_var_step_size():
    logger.info("test_callbacks_var_step_size")

    build_test_case_1cb(1, 2, 2)
    build_test_case_1cb(1, 3, 2)
    build_test_case_1cb(1, 4, 2)

    build_test_case_1cb(2, 2, 2)
    build_test_case_1cb(2, 3, 2)
    build_test_case_1cb(2, 4, 2)

    build_test_case_1cb(3, 2, 2)
    build_test_case_1cb(3, 3, 2)
    build_test_case_1cb(3, 4, 2)


def test_callbacks_all_2cbs():
    logger.info("test_callbacks_all_2cbs")

    build_test_case_2cbs(4, 1)
    build_test_case_2cbs(4, 2)
    build_test_case_2cbs(4, 3)
    build_test_case_2cbs(4, 4)


def test_callbacks_2maps():
    logger.info("test_callbacks_2maps")

    build_test_case_2maps(5, 10)

    build_test_case_2maps(6, 9)


class MyWaitedCallback(WaitedDSCallback):
    def __init__(self, events, step_size=1):
        super().__init__(step_size)
        self.events = events

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
        self.events.append(event)

    def sync_step_begin(self, train_run_context, ds_run_context):
        event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
        self.events.append(event)


class MyMSCallback(Callback):
    def __init__(self, events):
        self.events = events

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
        self.events.append(event)

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
        self.events.append(event)


class Net(nn.Cell):
    def construct(self, x, y):
        return x


def test_train_non_sink():
    logger.info("test_train_non_sink")

    events = []
    my_cb1 = MyWaitedCallback(events, 1)
    my_cb2 = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb1)

    net = Net()
    model = Model(net)

    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events == expected_synced_events


def test_train_batch_size2():
    logger.info("test_train_batch_size2")

    events = []
    my_cb1 = MyWaitedCallback(events, 2)
    my_cb2 = MyMSCallback(events)
    arr = [1, 2, 3, 4]
    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
    data = data.batch(2)
    net = Net()
    model = Model(net)

    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])

    # the map op runs before batch(2), so its step counter counts rows: with step_size=2 the
    # synced ds_step_begin events land on rows 3, 5, 7, while ms_step_end counts batches (2 per epoch)
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
                              'ms_step_end_1_2',
                              'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
                              'ms_step_end_2_4', 'ms_epoch_end_2_4']

    assert events == expected_synced_events


def test_callbacks_validations():
    logger.info("test_callbacks_validations")

    with pytest.raises(Exception) as err:
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data.map(operations=(lambda x: x), callbacks=0)
    assert "Argument callbacks with value 0 is not " in str(err.value)

    with pytest.raises(Exception) as err:
        my_cb1 = MyDSCallback()
        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
    assert "Argument callbacks[1] with value 0 is not " in str(err.value)

    with pytest.raises(Exception) as err:
        class BadCB(DSCallback):
            pass

        my_cb = BadCB()

        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
        data = data.map(operations=(lambda x: x), callbacks=my_cb)
        for _ in data:
            pass
    assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)


def test_callback_sink_simulation():
    logger.info("test_callback_sink_simulation")

    events = []
    epochs = 2
    my_cb = MyWaitedCallback(events, 1)
    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
    data = data.map(operations=(lambda x: x), callbacks=my_cb)
    data = data.to_device()
    data.send(num_epochs=epochs)
    for e in range(epochs):
        for s in range(4):
            time.sleep(0.5)
            events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
            my_cb.step_end(run_context=0)
        events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
        my_cb.epoch_end(run_context=0)
    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
                              'ms_step_end_2_8', 'ms_epoch_end_2_8']

    assert events == expected_synced_events


def test_callbacks_repeat():
    logger.info("test_callbacks_repeat")

    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
    build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
    build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)


if __name__ == '__main__':
    test_callbacks_all_methods()
    test_callbacks_all_2cbs()
    test_callbacks_2maps()
    test_callbacks_validations()
    test_callbacks_var_step_size()
    test_train_batch_size2()
    test_callback_sink_simulation()
    test_callbacks_repeat()