forked from mindspore-Ecosystem/mindspore
Implemented Callback for Dataset
Squashed commit history:
- Implement pause in MapOp; extend callback support
- Add ds_callback: initial drop of Python DSCallback and its pybind bindings
- Add callback support to MapOp and de_pipeline
- Add test case; fix a segfault; update the callback test case (now passing)
- Use the builder class for the MapOp callback; improve the test case
- Add comments and minor clean-ups; get rid of nullptr in MapOp, use a flag instead
- Fix a bug: ParseMapOp only takes one callback
- Add WaitedDSCallback; refactor the callback parameter
- Fix an incorrect number in a test case; add more C++ and Python testing
- Revert LeNet changes; clean up test_callbacks.py
- Fix CI stages I and II; update the epoch counter
- Add validation and more tests; test_callbacks.py now uses RandomDataOp to do tests
- Adjust when EpochBegin/EpochEnd are called; add a repeat-with-callback test
- Address reviewers' comments; docstring and CI fixes
- Rebase with upstream/master; fix the cpp test case; address review comments, add test case
This commit is contained in:
parent
89cd465268
commit
78c1aa1d96
@@ -58,6 +58,7 @@ add_subdirectory(kernels)
 add_subdirectory(engine)
 add_subdirectory(api)
 add_subdirectory(text)
+add_subdirectory(callback)
 ######################################################################
 add_dependencies(utils core)
 add_dependencies(kernels-image core)
@@ -74,6 +75,7 @@ add_dependencies(engine-cache-server core)
 add_dependencies(engine-perf core)
 add_dependencies(engine-gnn core)
 add_dependencies(engine core)
+add_dependencies(callback core)
 add_dependencies(text core)
 add_dependencies(text-kernels core)
 add_dependencies(cpp-API core)
@@ -87,6 +89,7 @@ endif ()
 ################### Create _c_dataengine Library ######################
 set(submodules
   $<TARGET_OBJECTS:core>
+  $<TARGET_OBJECTS:callback>
   $<TARGET_OBJECTS:utils>
   $<TARGET_OBJECTS:kernels>
   $<TARGET_OBJECTS:kernels-image>
@@ -135,14 +138,14 @@ endif()
 target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
 if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
   if (ENABLE_PYTHON)
     target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
   else()
     target_link_libraries(_c_dataengine PRIVATE mindspore::protobuf ${SECUREC_LIBRARY})
   endif()
 else()
   set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n)
   if (ENABLE_PYTHON)
     target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
   else()
     target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY})
   endif()
 endif()
@@ -7,6 +7,7 @@ if (ENABLE_PYTHON)
   python/bindings.cc
   python/bindings/dataset/engine/cache/bindings.cc
   python/bindings/dataset/core/bindings.cc
+  python/bindings/dataset/callback/bindings.cc
   python/bindings/dataset/kernels/data/bindings.cc
   python/bindings/dataset/kernels/bindings.cc
   python/bindings/dataset/engine/datasetops/bindings.cc
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pybind11/pybind11.h"
+#include "pybind11/stl_bind.h"
+
+#include "minddata/dataset/api/python/pybind_register.h"
+#include "minddata/dataset/callback/py_ds_callback.h"
+#include "minddata/dataset/callback/ds_callback.h"
+
+namespace mindspore {
+namespace dataset {
+
+PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
+                  (void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback")
+                    .def(py::init<int32_t>())
+                    .def("set_begin", &PyDSCallback::setBegin)
+                    .def("set_end", &PyDSCallback::setEnd)
+                    .def("set_epoch_begin", &PyDSCallback::setEpochBegin)
+                    .def("set_epoch_end", &PyDSCallback::setEpochEnd)
+                    .def("set_step_begin", &PyDSCallback::setStepBegin)
+                    .def("set_step_end", &PyDSCallback::setStepEnd);
+                }));
+
+PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
+                  (void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam")
+                    .def(py::init<int64_t, int64_t, int64_t>())
+                    .def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
+                    .def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
+                    .def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
+                }));
+}  // namespace dataset
+}  // namespace mindspore
@@ -20,6 +20,7 @@
 #include <map>

 #include "utils/ms_utils.h"
+#include "minddata/dataset/callback/py_ds_callback.h"
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/cache/cache_client.h"
 #include "minddata/dataset/engine/dataset_iterator.h"
@@ -738,8 +739,13 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
       (void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
     } else if (key == "cache") {
       cache_client = value.cast<std::shared_ptr<CacheClient>>();
+    } else if (key == "callbacks") {
+      std::vector<std::shared_ptr<DSCallback>> callbacks;
+      std::transform(value.begin(), value.end(), std::back_inserter(callbacks),
+                     [](py::handle cb) { return cb.cast<std::shared_ptr<PyDSCallback>>(); });
+      (void)map_builder.AddCallbacks(callbacks);
     } else {
-      RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
+      RETURN_STATUS_UNEXPECTED("Error in parsing MapOp: Unhandled key: " + key);
     }
   }
 }
@@ -0,0 +1,14 @@
+file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
+set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
+
+
+if (ENABLE_PYTHON)
+  add_library(callback OBJECT
+    callback_manager.cc
+    py_ds_callback.cc
+  )
+else ()
+  add_library(callback OBJECT
+    callback_manager.cc
+  )
+endif ()
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/callback/callback_manager.h"
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
+  callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
+}
+
+Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
+  RETURN_UNEXPECTED_IF_NULL(op);
+  op_ = op;
+  // turn the flag on if callback is set
+  enabled_ = !callbacks_.empty();
+
+  // error check for each of the callbacks
+  for (auto &cb : callbacks_) {
+    CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
+  }
+
+  return Status::OK();
+}
+
+Status CallbackManager::Begin(const CallbackParam &cb_param) {
+  RETURN_OK_IF_TRUE(!enabled_);
+  std::vector<size_t> callback_inds;
+  // go through all callback functions to see if each function is needed
+  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
+    if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
+  }
+  // return Status::OK() if no begin is needed
+  RETURN_OK_IF_TRUE(callback_inds.empty());
+
+  RETURN_IF_NOT_OK(op_->PauseFromMaster());
+
+  // Now do the actual callback
+  for (size_t ind : callback_inds) {
+    RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
+  }
+  return Status::OK();
+}
+
+Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
+  RETURN_OK_IF_TRUE(!enabled_);
+  std::vector<size_t> callback_inds;
+  // go through all callback functions to see if each function is needed
+  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
+    if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
+  }
+  // return Status::OK() if no epoch_begin is needed
+  RETURN_OK_IF_TRUE(callback_inds.empty());
+
+  RETURN_IF_NOT_OK(op_->PauseFromMaster());
+
+  // Now do the actual callback
+  for (size_t ind : callback_inds) {
+    RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
+  }
+  return Status::OK();
+}
+
+Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
+  RETURN_OK_IF_TRUE(!enabled_);
+  std::vector<size_t> callback_inds;
+  // go through all callback functions to see if each function is needed
+  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
+    if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
+      callback_inds.push_back(ind);
+  }
+  // return Status::OK() if no step_begin is needed
+  RETURN_OK_IF_TRUE(callback_inds.empty());
+
+  RETURN_IF_NOT_OK(op_->PauseFromMaster());
+
+  // Now do the actual callback
+  for (size_t ind : callback_inds) {
+    RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
+  }
+  return Status::OK();
+}
+
+Status CallbackManager::End(const CallbackParam &cb_param) {
+  RETURN_OK_IF_TRUE(!enabled_);
+  std::vector<size_t> callback_inds;
+  // go through all callback functions to see if each function is needed
+  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
+    if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
+  }
+  // return Status::OK() if no end is needed
+  RETURN_OK_IF_TRUE(callback_inds.empty());
+
+  RETURN_IF_NOT_OK(op_->PauseFromMaster());
+
+  // Now do the actual callback
+  for (size_t ind : callback_inds) {
+    RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
+  }
+  return Status::OK();
+}
+
+Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
+  RETURN_OK_IF_TRUE(!enabled_);
+  std::vector<size_t> callback_inds;
+  // go through all callback functions to see if each function is needed
+  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
+    if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
+  }
+  // return Status::OK() if no epoch_end is needed
+  RETURN_OK_IF_TRUE(callback_inds.empty());
+
+  RETURN_IF_NOT_OK(op_->PauseFromMaster());
+
+  // Now do the actual callback
+  for (size_t ind : callback_inds) {
+    RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
+  }
+  return Status::OK();
+}
+
+Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
+  RETURN_OK_IF_TRUE(!enabled_);
+  std::vector<size_t> callback_inds;
+  // go through all callback functions to see if each function is needed
+  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
+    if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
+      callback_inds.push_back(ind);
+  }
+  // return Status::OK() if no step_end is needed
+  RETURN_OK_IF_TRUE(callback_inds.empty());
+
+  RETURN_IF_NOT_OK(op_->PauseFromMaster());
+
+  // Now do the actual callback
+  for (size_t ind : callback_inds) {
+    RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
+  }
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
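A note on the step gating above: because StepBegin/StepEnd test `(cur_epoch_step_num_ - 1) % step_size == 0`, a callback with step_size N fires on 1-based steps 1, N+1, 2N+1, and so on. The sketch below is illustrative only and not part of the patch; `FiresOnStep` is a hypothetical helper that reproduces the condition.

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical helper reproducing the patch's gating test for N-step callbacks.
bool FiresOnStep(int64_t cur_epoch_step_num, int32_t step_size) {
  return (cur_epoch_step_num - 1) % step_size == 0;
}

int main() {
  // With step_size = 2, 1-based steps 1, 3, 5 fire; steps 2, 4, 6 are skipped.
  for (int64_t step = 1; step <= 6; ++step) {
    std::cout << "step " << step << (FiresOnStep(step, 2) ? ": fires" : ": skipped") << "\n";
  }
  return 0;
}
```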
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
+
+#include <memory>
+#include <vector>
+
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+// forward declare to avoid cyclic include of dataset_op.h
+class DatasetOp;
+
+/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
+class CallbackManager {
+ public:
+  /// CallbackManager default constructor. Init needs to be called before using the created instance.
+  CallbackManager() : enabled_(false) {}
+
+  /// \brief Adds a list of callbacks to the manager
+  /// \param[in] callbacks list of callbacks to perform
+  void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);
+
+  /// \brief DatasetOp needs to call Init if it wishes to use callbacks; Init will set enabled_ to true
+  /// \param[in] op this pointer is used by the CallbackManager to pause worker threads
+  /// \return Status
+  Status Init(std::shared_ptr<DatasetOp> op);
+
+  /// \brief callback function called at the start of the first row
+  /// \return Status
+  Status Begin(const CallbackParam &);
+
+  /// \brief callback function called at the start of each epoch
+  /// \return Status
+  Status EpochBegin(const CallbackParam &);
+
+  /// \brief callback function called at the start of each row
+  /// \return Status
+  Status StepBegin(const CallbackParam &);
+
+  /// \brief callback function called after the last row is processed
+  /// \return Status
+  Status End(const CallbackParam &);
+
+  /// \brief callback function called at the end of each epoch
+  /// \return Status
+  Status EpochEnd(const CallbackParam &);
+
+  /// \brief callback function called at the end of each row
+  /// \return Status
+  Status StepEnd(const CallbackParam &);
+
+ private:
+  bool enabled_;                                        // flag to enable callback; if false, all functions return immediately
+  std::shared_ptr<DatasetOp> op_;                       // back pointer to DatasetOp; each DatasetOp has only 1 CallbackManager
+  std::vector<std::shared_ptr<DSCallback>> callbacks_;  // list of callbacks the DatasetOp needs to call
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
+
+#include <nlohmann/json.hpp>
+
+namespace mindspore {
+namespace dataset {
+
+/// Callback Param is the object a DatasetOp uses to pass run-time information to user defined function.
+/// This is a prototype for now, more fields will be added
+class CallbackParam {
+ public:
+  CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
+      : cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {}
+
+  // these are constant public fields for easy access and consistency with python cb_param
+  // the names and orders are consistent with batchInfo
+  const int64_t cur_epoch_num_;       // current epoch
+  const int64_t cur_epoch_step_num_;  // step number of the current epoch
+  const int64_t cur_step_num_;        // step number since the first row
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/callback/callback_param.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class DSCallback {
+ public:
+  /// \brief constructor of DSCallback; this is the base class for all front-end-specific callbacks
+  /// \param step_size number of steps between calls to DSNStepBegin()
+  explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}
+
+  /// \brief actual callback function for begin, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  virtual Status DSBegin(const CallbackParam &cb_param) = 0;
+
+  /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;
+
+  /// \brief actual callback function for step_begin, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;
+
+  /// \brief actual callback function for end, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  virtual Status DSEnd(const CallbackParam &cb_param) = 0;
+
+  /// \brief actual callback function for epoch_end, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;
+
+  /// \brief actual callback function for step_end, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;
+
+  /// \brief predicate function, whether begin callback is needed
+  /// \return bool
+  virtual bool IsBeginNeeded() = 0;
+
+  /// \brief predicate function, whether epoch_begin callback is needed
+  /// \return bool
+  virtual bool IsEpochBeginNeeded() = 0;
+
+  /// \brief predicate function, whether step_begin callback is needed
+  /// \return bool
+  virtual bool IsNStepBeginNeeded() = 0;
+
+  /// \brief predicate function, whether end callback is needed
+  /// \return bool
+  virtual bool IsEndNeeded() = 0;
+
+  /// \brief predicate function, whether epoch_end callback is needed
+  /// \return bool
+  virtual bool IsEpochEndNeeded() = 0;
+
+  /// \brief predicate function, whether step_end callback is needed
+  /// \return bool
+  virtual bool IsNStepEndNeeded() = 0;
+
+  /// \brief getter
+  /// \return step_size
+  int32_t step_size() const { return step_size_; }
+
+ protected:
+  int32_t step_size_;  // step begin/end will be called every step_size_
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
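To make the contract above concrete, here is a minimal sketch of a DSCallback subclass that reacts only to epoch boundaries. The subclass itself (`EpochLogger` and its logging) is hypothetical and not part of the patch; only the DSCallback interface it implements comes from this commit. The Is*Needed() predicates matter: returning false lets CallbackManager skip the worker pause entirely for those hooks.

```cpp
#include <iostream>

#include "minddata/dataset/callback/ds_callback.h"  // header introduced by this commit

namespace mindspore {
namespace dataset {

// Hypothetical subclass for illustration: only the epoch hooks do any work.
class EpochLogger : public DSCallback {
 public:
  Status DSBegin(const CallbackParam &) override { return Status::OK(); }
  Status DSEpochBegin(const CallbackParam &cb) override {
    std::cout << "epoch " << cb.cur_epoch_num_ << " begins\n";
    return Status::OK();
  }
  Status DSNStepBegin(const CallbackParam &) override { return Status::OK(); }
  Status DSEnd(const CallbackParam &) override { return Status::OK(); }
  Status DSEpochEnd(const CallbackParam &cb) override {
    std::cout << "epoch " << cb.cur_epoch_num_ << " ends after " << cb.cur_epoch_step_num_ << " steps\n";
    return Status::OK();
  }
  Status DSNStepEnd(const CallbackParam &) override { return Status::OK(); }

  // Only the epoch hooks are needed, so CallbackManager never pauses for the others.
  bool IsBeginNeeded() override { return false; }
  bool IsEpochBeginNeeded() override { return true; }
  bool IsNStepBeginNeeded() override { return false; }
  bool IsEndNeeded() override { return false; }
  bool IsEpochEndNeeded() override { return true; }
  bool IsNStepEndNeeded() override { return false; }
};

}  // namespace dataset
}  // namespace mindspore
```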
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/callback/callback_manager.h"
+#include "minddata/dataset/callback/py_ds_callback.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status PyDSCallback::DSBegin(const CallbackParam &cb_param) {
+  return PyDSCallback::ExecutePyfunc(begin_func_, cb_param);
+}
+Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
+  return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param);
+}
+Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
+  return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param);
+}
+Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); }
+
+Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
+  return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param);
+}
+Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
+  return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param);
+}
+
+bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
+bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
+bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
+bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
+bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
+bool PyDSCallback::IsEndNeeded() { return end_needed_; }
+
+Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
+  {
+    // Acquire Python GIL
+    py::gil_scoped_acquire gil_acquire;
+    if (Py_IsInitialized() == 0) {
+      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    }
+    f(cb_param);
+  }
+  return Status::OK();
+}
+void PyDSCallback::setBegin(py::function f) {
+  begin_func_ = f;
+  begin_needed_ = true;
+}
+void PyDSCallback::setEnd(py::function f) {
+  end_func_ = f;
+  end_needed_ = true;
+}
+void PyDSCallback::setEpochBegin(py::function f) {
+  epoch_begin_func_ = f;
+  epoch_begin_needed_ = true;
+}
+void PyDSCallback::setEpochEnd(py::function f) {
+  epoch_end_func_ = f;
+  epoch_end_needed_ = true;
+}
+void PyDSCallback::setStepBegin(py::function f) {
+  step_begin_func_ = f;
+  step_begin_needed_ = true;
+}
+void PyDSCallback::setStepEnd(py::function f) {
+  step_end_func_ = f;
+  step_end_needed_ = true;
+}
+
+}  // namespace dataset
+}  // namespace mindspore
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/util/status.h"
+#include "pybind11/pybind11.h"
+
+namespace mindspore {
+namespace dataset {
+
+namespace py = pybind11;
+
+class PyDSCallback : public DSCallback {
+ public:
+  /// \brief constructor for PyDSCallback. This callback is for the python front end
+  explicit PyDSCallback(int32_t step_size = 1)
+      : DSCallback(step_size),
+        begin_needed_(false),
+        epoch_begin_needed_(false),
+        step_begin_needed_(false),
+        end_needed_(false),
+        epoch_end_needed_(false),
+        step_end_needed_(false) {}
+
+  void setBegin(py::function f);
+  void setEnd(py::function f);
+  void setEpochBegin(py::function f);
+  void setEpochEnd(py::function f);
+  void setStepBegin(py::function f);
+  void setStepEnd(py::function f);
+
+  /// \brief actual callback function for begin, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  Status DSBegin(const CallbackParam &cb_param) override;
+
+  /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  Status DSEpochBegin(const CallbackParam &cb_param) override;
+
+  /// \brief actual callback function for step_begin, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  Status DSNStepBegin(const CallbackParam &cb_param) override;
+
+  /// \brief actual callback function for end, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  Status DSEnd(const CallbackParam &cb_param) override;
+
+  /// \brief actual callback function for epoch_end, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  Status DSEpochEnd(const CallbackParam &cb_param) override;
+
+  /// \brief actual callback function for step_end, needs to be overridden in the derived class
+  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
+  /// \return Status
+  Status DSNStepEnd(const CallbackParam &cb_param) override;
+
+  /// \brief predicate function, whether begin callback is needed
+  /// \return bool
+  bool IsBeginNeeded() override;
+
+  /// \brief predicate function, whether epoch_begin callback is needed
+  /// \return bool
+  bool IsEpochBeginNeeded() override;
+
+  /// \brief predicate function, whether step_begin callback is needed
+  /// \return bool
+  bool IsNStepBeginNeeded() override;
+
+  /// \brief predicate function, whether end callback is needed
+  /// \return bool
+  bool IsEndNeeded() override;
+
+  /// \brief predicate function, whether epoch_end callback is needed
+  /// \return bool
+  bool IsEpochEndNeeded() override;
+
+  /// \brief predicate function, whether step_end callback is needed
+  /// \return bool
+  bool IsNStepEndNeeded() override;
+
+  /// \brief helper function to acquire GIL then execute a pyfunc
+  /// \param f the python function
+  /// \param cb_param
+  /// \return Status
+  static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);
+
+ private:
+  py::function begin_func_;
+  py::function epoch_begin_func_;
+  py::function step_begin_func_;
+  py::function end_func_;
+  py::function epoch_end_func_;
+  py::function step_end_func_;
+
+  bool begin_needed_;
+  bool epoch_begin_needed_;
+  bool step_begin_needed_;
+  bool end_needed_;
+  bool epoch_end_needed_;
+  bool step_end_needed_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
@@ -21,6 +21,8 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+
+#include "minddata/dataset/callback/callback_manager.h"
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/util/status.h"

@@ -358,6 +360,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return boolean returns true if it's last iteration
   bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
+
+  /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp.
+  /// The expected behavior is that, when invoked, this function blocks until all the workers have finished
+  /// their remaining work and gone to sleep. Since all ParallelOps use a QueueList to sync with the master,
+  /// they automatically wait on the QueueList when they are done. Hence, for now, an Unpause() function is
+  /// not needed. Only ParallelOp needs to override this function.
+  /// \return Status
+  virtual Status PauseFromMaster() { return Status::OK(); }
+
  protected:
   /// \brief Removes a parent operator from this operator
   /// \notes External callers do not have access to this function

@@ -394,6 +404,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   std::unique_ptr<DbConnector> out_connector_;                   // Output Connector
   std::unordered_map<std::string, int32_t> column_name_id_map_;  // Mapping between col index and col name
   std::mutex column_name_map_mutex_;                             // For protecting shared access to the column map
+  CallbackManager callback_manager_;                             // Manages callbacks associated with a DatasetOp

  private:
   /// Sets the operator id.
@@ -15,25 +15,23 @@
 */
 #include <algorithm>
 #include <cstring>
-#include <iomanip>
 #include <iostream>
 #include <memory>
 #include <vector>
 #include "minddata/dataset/core/config_manager.h"
+
+#include "minddata/dataset/callback/callback_param.h"
 #include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/core/global_context.h"
-#include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/engine/data_buffer.h"
-#include "minddata/dataset/engine/db_connector.h"
-#include "minddata/dataset/engine/execution_tree.h"
-#include "minddata/dataset/engine/opt/pass.h"
-#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
 #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
 #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
+#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/kernels/tensor_op.h"
-#include "utils/log_adapter.h"
 #include "minddata/dataset/util/task_manager.h"
+#include "utils/log_adapter.h"

 namespace mindspore {
 namespace dataset {
@@ -58,6 +56,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
   RETURN_IF_NOT_OK(sanityCheck());
   *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
                                  std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
+  (*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_));
   return Status::OK();
 }

@@ -164,7 +163,10 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)
 Status MapOp::operator()() {
   // Create and register the local queues.
   local_queues_.Init(num_workers_, oc_queue_size_);
+  // init callback
+  RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
   Status rc = local_queues_.Register(tree_->AllTasks());
+  RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
   if (rc.IsError()) {
     TaskManager::FindMe()->Post();
     return rc;

@@ -175,28 +177,51 @@ Status MapOp::operator()() {
   // Synchronize with TaskManager
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(rc);
+  // num_buffers received, including eoe, num_epoch, num_step of current epoch
+  int64_t num_buf = 0, ep_step = 0, total_step = 0;
+  RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
+
-  int64_t que_id = 0;
   std::unique_ptr<DataBuffer> buff;
-  bool is_eof = false;
-  // Drain output connector of the previous op, generate jobs for worker threads, and distribute them via local queues
-  // Stop when all worker threads are finished (received EOF)
-  while (!is_eof) {
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
+  while (!buff->eof()) {
+    if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
+      RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+    }
+    while (!buff->eoe()) {
+      ep_step++;
+      total_step++;
+      // Create an empty map worker job to be populated by a databuffer and map jobs
+      RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+      std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
+
+      // Populate map worker job for a worker to execute
+      RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
+
+      // Push map worker job to the corresponding worker's queue
+      RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
+      RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
+    }
+
+    // send the eoe buffer to worker
+
+    // reset epoch_step when a new epoch is about to start
+    if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
+      RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
+      ep_step = 0;
+    }
+    std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
+    RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
+    UpdateRepeatAndEpochCounter();
     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
-    is_eof = buff->eof();
-
-    // Create an empty map worker job to be populated by a databuffer and map jobs
-    std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>();
-    worker_job->databuffer = std::move(buff);
-
-    // Populate map worker job for a worker to execute
-    RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
-
-    // Push map worker job to the corresponding worker's queue
-    RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(worker_job)));
-    que_id = (que_id + 1) % num_workers_;
   }
+  // the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback
+  // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
+  // handle eof logic
+  std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
+  RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
   return Status::OK();
 }
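Reading the rewritten operator() above: for a run of E epochs with S steps each, the master thread fires Begin once, then per epoch EpochBegin, S StepBegin/StepEnd pairs, and EpochEnd; End is deliberately left commented out because the final eoe would skew its counters. The standalone sketch below only reproduces that ordering; it uses no MindSpore types, and the exact interleaving with worker pushes is simplified.

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t num_epochs = 2, steps_per_epoch = 3;  // arbitrary example sizes
  int64_t ep_step = 0, total_step = 0;
  std::cout << "Begin(epoch=0, step=0, total=0)\n";  // fired once, before the first buffer
  for (int64_t ep = 1; ep <= num_epochs; ++ep) {
    std::cout << "EpochBegin(epoch=" << ep << ")\n";
    for (int64_t s = 1; s <= steps_per_epoch; ++s) {
      ++ep_step;
      ++total_step;
      std::cout << "  StepBegin(epoch=" << ep << ", step=" << ep_step << ", total=" << total_step << ")\n";
      std::cout << "  StepEnd(epoch=" << ep << ", step=" << ep_step << ", total=" << total_step << ")\n";
    }
    std::cout << "EpochEnd(epoch=" << ep << ")\n";
    ep_step = 0;  // reset when a new epoch is about to start, as in the patch
  }
  // End() is intentionally not invoked; the patch leaves it commented out.
  return 0;
}
```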
@@ -213,25 +238,19 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
   // Fetch next data buffer and map job list
   RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
-
-  // Sanity check the databuffer.
-  // Special case: if there's more threads than buffers, some threads simply get the final control
-  // messages (eoe/eof), and so they will not perform the check.
-  if (!in_buffer->eoe() && !in_buffer->eof()) {
-    int32_t num_rows = in_buffer->NumRows();
-    int32_t num_cols = in_buffer->NumCols();
-    if (num_rows == 0 || num_cols == 0) {
-      RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer.");
-    }
-  }
-
   // Now that init work is done, drop into the main fetching loop.
   // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
   // rather than use the base-class defaults.
   while (true) {
-    // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
-    // with Performance Mode design.
-    if (in_buffer->eoe()) {
-      UpdateRepeatAndEpochCounter();
+    // handle the pause logic. Pause is triggered when a buffer id of -1 with no special flag and no row is received
+    if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
+      // when a worker receives the signal from the master thread, it increments an atomic int
+      // the last one to increment the counter wakes up the master thread
+      if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
+      // this will block the worker until the master thread gives it new work
+      RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
+      continue;
+    } else if (in_buffer->eoe()) {
       // Calling base class EoeReceived to forward eoe buffer.
       RETURN_IF_NOT_OK(EoeReceived(worker_id));
       // Fetch next data buffer and map job list

@@ -243,6 +262,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
       break;
     }
+
+    CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
     std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
     // Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
     RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list));
@@ -281,9 +301,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table,
   std::vector<TensorRow> result_table;
   // Executing the list of jobs
   for (size_t i = 0; i < job_list.size(); i++) {
-    // Executre MapJob.
+    // Execute MapJob.
     RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
-    // Assign the pocessed data as an input for the next job processing, except for the last TensorOp in the list.
+    // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
     if (i + 1 < job_list.size()) {
       job_input_table = std::move(result_table);
     }
@@ -428,5 +448,20 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
   // Downcast shared pointer then call visitor
   return p->RunOnNode(shared_from_base<MapOp>(), modified);
 }
+
+Status MapOp::PauseFromMaster() {
+  // reset num_paused workers to 0
+  num_workers_paused_ = 0;
+  for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
+    // a special buffer (id=-1, empty, none flag) is used to signal that worker needs to pause.
+    RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
+      std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
+  }
+  // wait until all workers are done processing their work in local_queue_
+  RETURN_IF_NOT_OK(master_pause_wp_.Wait());
+  // clear the WaitPost for the next Wait()
+  master_pause_wp_.Clear();
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
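PauseFromMaster above rides on the existing QueueList/WaitPost machinery: the master enqueues one pause token per worker, then blocks until the last worker's atomic increment sets the WaitPost. The standalone sketch below mimics just that rendezvous with a plain std::condition_variable standing in for WaitPost; it is a simplified illustration with hypothetical names, not the patch's implementation (real workers would keep looping on their queues rather than exit).

```cpp
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Simplified stand-in for WaitPost: the master waits until a worker posts.
class WaitPostLike {
 public:
  void Wait() {
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return set_; });
  }
  void Set() {
    { std::lock_guard<std::mutex> lk(m_); set_ = true; }
    cv_.notify_all();
  }
  void Clear() { std::lock_guard<std::mutex> lk(m_); set_ = false; }
 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool set_ = false;
};

int main() {
  const int num_workers = 4;
  std::atomic_int num_paused{0};
  WaitPostLike master_wp;

  // Each worker "receives the pause token"; the last one wakes the master,
  // mirroring: if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
  std::vector<std::thread> workers;
  for (int i = 0; i < num_workers; ++i) {
    workers.emplace_back([&] {
      if (++num_paused == num_workers) master_wp.Set();
    });
  }
  master_wp.Wait();   // master blocks until all workers have checked in
  master_wp.Clear();  // ready for the next pause cycle
  for (auto &t : workers) t.join();
  std::cout << "all " << num_paused.load() << " workers paused\n";
  return 0;
}
```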
@ -16,15 +16,19 @@
|
||||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
|
||||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
 #include <vector>
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
 #include "minddata/dataset/engine/datasetops/parallel_op.h"
 #include "minddata/dataset/kernels/tensor_op.h"
 #include "minddata/dataset/util/queue.h"
-#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
+#include "minddata/dataset/util/wait_post.h"

 namespace mindspore {
 namespace dataset {
@@ -108,6 +112,13 @@ class MapOp : public ParallelOp {
       return *this;
     }

+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &AddCallbacks(const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
+      builder_callbacks_.insert(builder_callbacks_.end(), callbacks.begin(), callbacks.end());
+      return *this;
+    }
+
     // The builder "build" method creates the final object.
     // @param ptr The shared_ptr to the new MapOp object
     // @return Status
@@ -116,6 +127,7 @@ class MapOp : public ParallelOp {
   private:
    std::vector<std::string> build_in_col_names_;
    std::vector<std::string> build_out_col_names_;
+   std::vector<std::shared_ptr<DSCallback>> builder_callbacks_;
    std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
    int32_t build_num_workers_;
    int32_t build_op_connector_size_;
@@ -186,6 +198,7 @@ class MapOp : public ParallelOp {
  // A unit of job for map worker thread.
  // MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
  struct MapWorkerJob {
+   explicit MapWorkerJob(std::unique_ptr<DataBuffer> db) : databuffer(std::move(db)) {}
    std::vector<std::shared_ptr<MapJob>> jobs;
    std::unique_ptr<DataBuffer> databuffer;
  };
@@ -215,6 +228,12 @@ class MapOp : public ParallelOp {
  // Indices of the columns to process.
  std::vector<size_t> to_process_indices_;

+ // Wait post used to perform the pausing logic in MapOp
+ WaitPost master_pause_wp_;
+
+ // Count of the number of workers that have signaled the master
+ std::atomic_int num_workers_paused_;
+
  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of MapOp: getting the data from the previous Op, validating user-specified column names,
  // applying a list of TensorOps to each piece of data, processing the results, and then
@@ -247,6 +266,13 @@ class MapOp : public ParallelOp {
  // Private function for initializing private variables such as in_columns_, out_columns_.
  // @return - Status
  Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);
+
+ // This function should only be called from the master thread. It suspends the operation of all workers and
+ // has them wait on the QueueList. The master thread sends a token to each worker and then waits on a WaitPost.
+ // Workers, upon receiving the suspension token from the master thread, increment an atomic count; the last worker
+ // to do the increment wakes up the master.
+ // @return - Status
+ Status PauseFromMaster() override;
 };
 }  // namespace dataset
 }  // namespace mindspore
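
For readers unfamiliar with the WaitPost pattern, the pause handshake described in the PauseFromMaster comment can be modeled in a few lines of Python. This is an illustrative sketch only (threading.Event stands in for WaitPost, and all names are hypothetical), not the C++ implementation in this commit:

import threading

NUM_WORKERS = 4
pause_tokens = [threading.Event() for _ in range(NUM_WORKERS)]
master_wait_post = threading.Event()   # stands in for WaitPost
num_workers_paused = 0                 # stands in for the atomic counter
count_lock = threading.Lock()

def worker(wid):
    global num_workers_paused
    pause_tokens[wid].wait()           # worker receives the suspension token
    with count_lock:
        num_workers_paused += 1
        if num_workers_paused == NUM_WORKERS:
            master_wait_post.set()     # last worker to pause wakes the master

def pause_from_master():
    for token in pause_tokens:
        token.set()                    # send a token to each worker
    master_wait_post.wait()            # block until all workers have paused

threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_WORKERS)]
for th in threads:
    th.start()
pause_from_master()
for th in threads:
    th.join()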
@@ -34,7 +34,7 @@ class Semaphore {
   /// \brief Decrement the internal counter. Will be blocked if the value is 0.
   /// \return Error code. Can get interrupt.
   Status P();
-  /// \brief Increment the internal counter. Wakeup on of the watiers if any.
+  /// \brief Increment the internal counter. Wake up one of the waiters if any.
   void V();
   /// \brief Peek the internal value
   /// \return The internal value
@@ -59,6 +59,13 @@ namespace dataset {
   }                                 \
 } while (false)

+#define RETURN_OK_IF_TRUE(_condition) \
+  do {                                \
+    if (_condition) {                 \
+      return Status::OK();            \
+    }                                 \
+  } while (false)
+
 enum class StatusCode : char {
   kOK = 0,
   kOutOfMemory = 1,
@@ -0,0 +1,18 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""init file for python callback"""
+from .ds_callback import DSCallback, WaitedDSCallback
+
+__all__ = ["DSCallback", "WaitedDSCallback"]
@@ -0,0 +1,232 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Python callback class
+"""
+import threading
+from mindspore._c_dataengine import PyDSCallback
+from mindspore.train.callback import Callback
+from .validators import check_callback
+
+
+class DSCallback:
+    """
+    Abstract base class used to build a dataset callback class.
+
+    Args:
+        step_size (int, optional): The number of steps between two consecutive step_begin/step_end calls (Default=1).
+
+    Examples:
+        >>> class PrintInfo(DSCallback):
+        >>>     def ds_epoch_end(self, ds_run_context):
+        >>>         print(ds_run_context.cur_epoch_num)
+        >>>         print(ds_run_context.cur_step_num)
+        >>>
+        >>> data = data.map(operations=op, callbacks=PrintInfo())
+    """
+
+    @check_callback
+    def __init__(self, step_size=1):
+        self.step_size = step_size
+
+    def ds_begin(self, ds_run_context):
+        """
+        Called before the data pipeline is started.
+
+        Args:
+            ds_run_context (RunContext): Contains some information of the pipeline.
+        """
+
+    def ds_epoch_begin(self, ds_run_context):
+        """
+        Called before a new epoch is started.
+
+        Args:
+            ds_run_context (RunContext): Contains some information of the pipeline.
+        """
+
+    def ds_epoch_end(self, ds_run_context):
+        """
+        Called after an epoch is finished.
+
+        Args:
+            ds_run_context (RunContext): Contains some information of the pipeline.
+        """
+
+    def ds_step_begin(self, ds_run_context):
+        """
+        Called before each group of n steps is started.
+
+        Args:
+            ds_run_context (RunContext): Contains some information of the pipeline.
+        """
+
+    def ds_step_end(self, ds_run_context):
+        """
+        Called after each group of n steps is finished.
+
+        Args:
+            ds_run_context (RunContext): Contains some information of the pipeline.
+        """
+
+    def create_runtime_obj(self):
+        """
+        Creates a runtime (C++) object from the callback methods defined by the user.
+
+        Returns: _c_dataengine.PyDSCallback
+        """
+        c_cb = PyDSCallback(self.step_size)
+        at_least_one = False
+
+        if self.__class__.ds_begin != DSCallback.ds_begin:
+            c_cb.set_begin(self.ds_begin)
+            at_least_one = True
+
+        if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
+            c_cb.set_epoch_begin(self.ds_epoch_begin)
+            at_least_one = True
+        if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
+            c_cb.set_epoch_end(self.ds_epoch_end)
+            at_least_one = True
+
+        if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
+            c_cb.set_step_begin(self.ds_step_begin)
+            at_least_one = True
+        if self.__class__.ds_step_end != DSCallback.ds_step_end:
+            c_cb.set_step_end(self.ds_step_end)
+            at_least_one = True
+
+        if not at_least_one:
+            raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
+
+        return c_cb
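
The registration logic in create_runtime_obj binds a Python hook into the C++ PyDSCallback only when the subclass actually overrode it, by comparing function objects on the class. A minimal, self-contained sketch of that detection trick (hypothetical Base/Child/Quiet names, not part of this commit):

class Base:
    def hook(self, ctx):
        pass  # default hook does nothing; subclasses may override

class Child(Base):
    def hook(self, ctx):
        print("overridden")

class Quiet(Base):
    pass

# An overridden hook is a different function object on the subclass,
# while an inherited hook is the very same object as on the base class.
assert Child.hook is not Base.hook   # overridden -> would be registered
assert Quiet.hook is Base.hook       # inherited  -> would be skipped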
+
+
+class WaitedDSCallback(Callback, DSCallback):
+    """
+    Abstract base class used to build a dataset callback class that is synchronized with the training callbacks.
+
+    This class can be used to execute user-defined logic right after the previous step or epoch.
+    For example, an augmentation may need the loss from the previously trained epoch to update some of its parameters.
+
+    Examples:
+        >>> my_cb = MyWaitedCallback(32)
+        >>> data = data.map(operations=AugOp(), callbacks=my_cb)
+        >>> data = data.batch(32)
+        >>> # define the model
+        >>> model.train(epochs, data, callbacks=[my_cb])
+
+    Args:
+        step_size (int, optional): The number of rows in each step.
+            Usually the step size will be equal to the batch size (Default=1).
+    """
+
+    def __init__(self, step_size=1):
+        super().__init__()
+        self.step_size = step_size
+        self.step_event = threading.Event()
+        self.step_run_context = None
+
+        self.epoch_event = threading.Event()
+        self.epoch_run_context = None
+
+    def sync_epoch_begin(self, train_run_context, ds_run_context):
+        """
+        Called before a new dataset epoch is started and after the previous training epoch is ended.
+
+        Args:
+            train_run_context: Contains information of the model with feedback from the previous epoch.
+            ds_run_context: Contains information of the dataset pipeline.
+        """
+
+    def sync_step_begin(self, train_run_context, ds_run_context):
+        """
+        Called before a new dataset step is started and after the previous training step is ended.
+
+        Args:
+            train_run_context: Contains information of the model with feedback from the previous step.
+            ds_run_context: Contains information of the dataset pipeline.
+        """
+
+    def epoch_end(self, run_context):
+        """
+        Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.
+
+        Args:
+            run_context: Contains information of the model.
+        """
+        self.epoch_run_context = run_context
+        self.epoch_event.set()
+        self.epoch_event.clear()
+
+    def ds_epoch_begin(self, ds_run_context):
+        """
+        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for the MS epoch_end
+        callback.
+
+        Args:
+            ds_run_context: Contains information of the pipeline.
+        """
+        if ds_run_context.cur_epoch_num > 1:
+            if self.epoch_run_context is None:
+                self.epoch_event.wait()
+            self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
+            self.epoch_run_context = None
+
+    def step_end(self, run_context):
+        """
+        Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.
+
+        Args:
+            run_context: Contains information of the model.
+        """
+        self.step_run_context = run_context
+        self.step_event.set()
+        self.step_event.clear()
+
+    def ds_step_begin(self, ds_run_context):
+        """
+        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for the MS step_end
+        callback.
+
+        Args:
+            ds_run_context: Contains information of the pipeline.
+        """
+        if ds_run_context.cur_step_num > self.step_size:
+            if self.step_run_context is None:
+                self.step_event.wait()
+            self.sync_step_begin(self.step_run_context, ds_run_context)
+            self.step_run_context = None
+
+    def create_runtime_obj(self):
+        """
+        Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
+
+        Returns: _c_dataengine.PyDSCallback
+        """
+        c_cb = PyDSCallback(self.step_size)
+        at_least_one = False
+
+        if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
+            c_cb.set_step_begin(self.ds_step_begin)
+            at_least_one = True
+
+        if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
+            c_cb.set_epoch_begin(self.ds_epoch_begin)
+            at_least_one = True
+
+        if not at_least_one:
+            raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
+
+        return c_cb
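
The two sync_* hooks are the only user-facing surface of WaitedDSCallback; the Event plumbing makes each dataset step/epoch wait for the corresponding training step/epoch to finish. A hedged sketch of a subclass that feeds training feedback back into the pipeline (UpdateAugCallback, aug_op, and its update method are hypothetical, not part of this commit):

from mindspore.dataset.callback import WaitedDSCallback

class UpdateAugCallback(WaitedDSCallback):
    def __init__(self, aug_op, step_size=32):
        super().__init__(step_size)
        self.aug_op = aug_op  # hypothetical op exposing update()

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # Runs after training epoch N ended, before dataset epoch N+1 begins.
        # train_run_context is assumed to be a mindspore RunContext here.
        self.aug_op.update(train_run_context.original_args().net_outputs)

The same instance is then passed both to map(..., callbacks=my_cb) and to model.train(..., callbacks=[my_cb]), as in the class docstring above.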
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Built-in validators.
+"""
+
+from functools import wraps
+
+from ..core.validator_helpers import parse_user_args, check_pos_int32
+
+
+def check_callback(method):
+    """check the input arguments of DSCallback."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [step_size], _ = parse_user_args(method, *args, **kwargs)
+        check_pos_int32(step_size, "step_size")
+        return method(self, *args, **kwargs)
+
+    return new_method
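
Because of this decorator, an invalid step_size is rejected before DSCallback.__init__ runs. A small sketch of the expected behavior (assuming check_pos_int32 raises ValueError for non-positive values, as its name suggests):

from mindspore.dataset.callback import DSCallback

cb = DSCallback(step_size=2)   # accepted: positive int32
try:
    DSCallback(step_size=0)    # rejected by check_pos_int32
except ValueError as e:
    print(e)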
@@ -44,7 +44,7 @@
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
     check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\
+    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
     check_paddeddataset
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
 from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
@@ -395,7 +395,7 @@ class Dataset:

     @check_map
     def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
-            num_parallel_workers=None, python_multiprocessing=False, cache=None):
+            num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
         """
         Apply each operation in operations to this dataset.
@@ -438,6 +438,8 @@ class Dataset:
             option could be beneficial if the python operation is computational heavy (default=False).
             cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
                 The cache feature is under development and is not recommended.
+            callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called
+                (Default=None).

         Returns:
             MapDataset, dataset after mapping operation.
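
A short usage sketch of the new parameter (EpochPrinter is a hypothetical subclass; any of the DSCallback hooks may be overridden):

import mindspore.dataset as ds
from mindspore.dataset.callback import DSCallback

class EpochPrinter(DSCallback):
    def ds_epoch_end(self, ds_run_context):
        print("epoch", ds_run_context.cur_epoch_num, "done")

data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
# callbacks accepts a single DSCallback or a list of them
data = data.map(operations=(lambda x: x), callbacks=EpochPrinter())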
@@ -552,7 +554,7 @@ class Dataset:
             >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
         """
         return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
-                          python_multiprocessing, cache)
+                          python_multiprocessing, cache, callbacks)

     @check_filter
     def filter(self, predicate, input_columns=None, num_parallel_workers=1):
@@ -1548,6 +1550,7 @@ class DatasetOp(Dataset):
             return self.children[0].get_class_indexing()
         raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))


 class BucketBatchByLengthDataset(DatasetOp):
     """
     The result of applying BucketBatchByLength operator to the input dataset.
@@ -1964,14 +1967,14 @@ class MapDataset(DatasetOp):
             option could be beneficial if the python operation is computational heavy (default=False).
         cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
             The cache feature is under development and is not recommended.
+        callbacks (DSCallback, list[DSCallback], optional): List of Dataset callbacks to be called (Default=None).

     Raises:
         ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
     """

     def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
-                 num_parallel_workers=None, python_multiprocessing=False, cache=None):
+                 num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
         super().__init__(num_parallel_workers)
         self.children.append(input_dataset)
         if input_columns is not None and not isinstance(input_columns, list):
@@ -1996,6 +1999,11 @@ class MapDataset(DatasetOp):
         self.python_multiprocessing = python_multiprocessing
         self.process_pool = None

+        if callbacks is not None and not isinstance(callbacks, list):
+            callbacks = [callbacks]
+
+        self.callbacks = callbacks
+
     def get_args(self):
         args = super().get_args()
         args["input_columns"] = self.input_columns
@@ -2003,6 +2011,9 @@ class MapDataset(DatasetOp):
         args["output_columns"] = self.output_columns
         args["columns_order"] = self.columns_order
         args["cache"] = self.cache.cache_client if self.cache is not None else None
+
+        if self.callbacks is not None:
+            args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks]
         return args

     def get_dataset_size(self):
@@ -2034,6 +2045,7 @@ class MapDataset(DatasetOp):
         new_op.cache = copy.deepcopy(self.cache, memodict)
         new_op.operations = self.operations
         new_op.dataset_size = self.dataset_size
+        new_op.callbacks = self.callbacks
         return new_op

     # Iterator bootstrap will be called on iterator construction.
@@ -2393,7 +2405,6 @@ class ConcatDataset(DatasetOp):
                 self._children_start_end_index_[index][0] = cumulative_samples_nums
                 self._children_start_end_index_[index][1] = tem_value % sampler.num_shards

                 tem_sampler = copy.deepcopy(sampler)
                 tem_sampler.set_offset(cumulative_samples_nums)
                 child.sampler = tem_sampler
@@ -2556,7 +2567,7 @@ class RangeDataset(MappableDataset):

     def get_dataset_size(self):
         if self.dataset_size is None:
-            self.dataset_size = math.ceil((self.stop - self.start)/self.step)
+            self.dataset_size = math.ceil((self.stop - self.start) / self.step)
         return self.dataset_size


@@ -3423,7 +3434,7 @@ class GeneratorDataset(MappableDataset):
             if not self.num_shards:
                 self.dataset_size = len(self.source)
             else:
-                self.dataset_size = math.ceil(len(self.source)/self.num_shards)
+                self.dataset_size = math.ceil(len(self.source) / self.num_shards)

             rows_from_sampler = self._get_sampler_dataset_size()
             if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
@@ -5428,6 +5439,7 @@ class NumpySlicesDataset(GeneratorDataset):
                          num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
                          num_shards=num_shards, shard_id=shard_id)


 class _PaddedDataset:
     """
     Mainly for combining false samples provided by users into a dataset.
@@ -5435,6 +5447,7 @@ class _PaddedDataset:
     Args:
         padded_samples (list(dict)): The data provided by the user to be added to the initial dataset.
     """

     def __init__(self, padded_samples):
         self.column_names = list(padded_samples[0].keys())
         self.padded_samples = padded_samples
@@ -5445,6 +5458,7 @@ class _PaddedDataset:
     def __len__(self):
         return len(self.padded_samples)


 class PaddedDataset(GeneratorDataset):
     """
     Create a dataset with fake data provided by the user. Mainly used to add to the original dataset
@@ -5463,6 +5477,7 @@ class PaddedDataset(GeneratorDataset):
         >>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
         >>> ds1 = ds.PaddedDataset(data1)
     """

     @check_paddeddataset
     def __init__(self, padded_samples):
         dataset = _PaddedDataset(padded_samples)
@@ -23,6 +23,7 @@ from functools import wraps

 import numpy as np
 from mindspore._c_expression import typing
+from mindspore.dataset.callback import DSCallback
 from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
     INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
     validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
@@ -31,6 +32,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_list
 from . import datasets
 from . import samplers
 from . import cache_client
+from .. import callback


 def check_imagefolderdatasetv2(method):
@@ -247,6 +249,7 @@ def check_celebadataset(method):

     return new_method


 def check_save(method):
     """A wrapper that wraps a parameter checker around the save op."""

@@ -257,7 +260,7 @@ def check_save(method):
         nreq_param_int = ['num_files']
         nreq_param_str = ['file_name', 'file_type']
         validate_dataset_param_value(nreq_param_int, param_dict, int)
-        if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
+        if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
             raise ValueError("num_files should between {} and {}.".format(1, 1000))
         validate_dataset_param_value(nreq_param_str, param_dict, str)
         if param_dict.get('file_type') != 'mindrecord':
@@ -265,6 +268,8 @@ def check_save(method):
         return method(self, *args, **kwargs)

     return new_method


 def check_minddataset(method):
     """A wrapper that wraps a parameter checker around the original Dataset (MindDataset)."""

@@ -362,6 +367,7 @@ def check_generatordataset(method):

     return new_method


 def check_random_dataset(method):
     """A wrapper that wraps a parameter checker around the original Dataset (RandomDataset)."""

@@ -545,7 +551,8 @@ def check_map(method):

     @wraps(method)
     def new_method(self, *args, **kwargs):
-        [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \
+        [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache,
+         callbacks], _ = \
             parse_user_args(method, *args, **kwargs)

         nreq_param_columns = ['input_columns', 'output_columns']
@@ -558,9 +565,17 @@ def check_map(method):
         if cache is not None:
             type_check(cache, (cache_client.DatasetCache,), "cache")

+        if callbacks is not None:
+            if isinstance(callbacks, (list, tuple)):
+                type_check_list(callbacks, (callback.DSCallback,), "callbacks")
+            else:
+                type_check(callbacks, (callback.DSCallback,), "callbacks")
+
         for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
             if param is not None:
                 check_columns(param, param_name)
+        if callbacks is not None:
+            type_check(callbacks, (list, DSCallback), "callbacks")

         return method(self, *args, **kwargs)
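
The net effect of the checks above: anything that is not a DSCallback (or a list/tuple of them) is rejected when map is called, not when the pipeline runs. A hedged sketch (this sketch assumes type_check raises TypeError, matching the message asserted in test_callbacks.py below):

import mindspore.dataset as ds

data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
try:
    data.map(operations=(lambda x: x), callbacks=0)  # 0 is not a DSCallback
except TypeError as e:
    print(e)  # e.g. "Argument callbacks with value 0 is not ..."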
@@ -15,6 +15,7 @@ SET(DE_UT_SRCS
     bounding_box_augment_op_test.cc
     arena_test.cc
    btree_test.cc
+    callback_test.cc
     center_crop_op_test.cc
     channel_swap_test.cc
     circular_pool_test.cc
@@ -0,0 +1,301 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <list>
+
+#include "common/common.h"
+#include "minddata/dataset/callback/ds_callback.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
+#include "minddata/dataset/kernels/data/no_op.h"
+#include "utils/log_adapter.h"
+
+using namespace mindspore::dataset;
+using mindspore::LogStream;
+using mindspore::MsLogLevel::INFO;
+
+namespace mindspore {
+namespace dataset {
+namespace test {
+
+std::shared_ptr<ExecutionTree> BuildTree(std::vector<std::shared_ptr<DatasetOp>> ops) {
+  std::shared_ptr<ExecutionTree> tree = std::make_shared<ExecutionTree>();
+  Status rc;
+  for (int i = 0; i < ops.size(); i++) {
+    rc = tree->AssociateNode(ops[i]);
+    EXPECT_TRUE(rc.IsOk());
+    if (i > 0) {
+      rc = ops[i]->AddChild(ops[i - 1]);
+      EXPECT_TRUE(rc.IsOk());
+    }
+    if (i == ops.size() - 1) {
+      rc = tree->AssignRoot(ops[i]);
+      EXPECT_TRUE(rc.IsOk());
+    }
+  }
+  return tree;
+}
+
+class TestCallback : public DSCallback {
+ public:
+  explicit TestCallback(int32_t step_size)
+      : DSCallback(step_size),
+        begin_(true),
+        epoch_begin_(true),
+        step_begin_(true),
+        end_(true),
+        epoch_end_(true),
+        step_end_(true) {
+    all_names_.reserve(32);
+    all_step_nums_.reserve(32);
+    all_ep_nums_.reserve(32);
+  }
+
+  Status DSBegin(const CallbackParam &cb_param) override {
+    all_names_.push_back("BGN");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSEpochBegin(const CallbackParam &cb_param) override {
+    all_names_.push_back("EPBGN");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSNStepBegin(const CallbackParam &cb_param) override {
+    all_names_.push_back("SPBGN");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSEnd(const CallbackParam &cb_param) override {
+    all_names_.push_back("END");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSEpochEnd(const CallbackParam &cb_param) override {
+    all_names_.push_back("EPEND");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+  Status DSNStepEnd(const CallbackParam &cb_param) override {
+    all_names_.push_back("SPEND");
+    all_step_nums_.push_back(cb_param.cur_step_num_);
+    all_ep_nums_.push_back(cb_param.cur_epoch_num_);
+    return Status::OK();
+  }
+
+  bool IsBeginNeeded() override { return begin_; }
+  bool IsEpochBeginNeeded() override { return epoch_begin_; }
+  bool IsNStepBeginNeeded() override { return step_begin_; }
+  bool IsEndNeeded() override { return end_; }
+  bool IsEpochEndNeeded() override { return epoch_end_; }
+  bool IsNStepEndNeeded() override { return step_end_; }
+
+  std::vector<std::string> all_names(size_t len) {
+    return std::vector<std::string>(all_names_.begin(), all_names_.begin() + len);
+  }
+
+  std::vector<int64_t> all_step_nums(size_t len) {
+    return std::vector<int64_t>(all_step_nums_.begin(), all_step_nums_.begin() + len);
+  }
+
+  std::vector<int64_t> all_ep_nums(size_t len) {
+    return std::vector<int64_t>(all_ep_nums_.begin(), all_ep_nums_.begin() + len);
+  }
+
+  // flags for turning each callback on and off
+  bool begin_, epoch_begin_, step_begin_, end_, epoch_end_, step_end_;
+  // names of the callback functions in sequence: BGN, EPBGN, SPBGN, END, EPEND, SPEND
+  std::vector<std::string> all_names_;
+  std::vector<int64_t> all_step_nums_, all_ep_nums_;
+};
+
+}  // namespace test
+}  // namespace dataset
+}  // namespace mindspore
+
+class MindDataTestCallback : public UT::DatasetOpTesting {
+ public:
+  void SetUp() override {
+    DatasetOpTesting::SetUp();
+    GlobalInit();
+  }
+};
+
+TEST_F(MindDataTestCallback, TestBasicCallback) {
+  // config callback
+  Status rc;
+  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
+  std::shared_ptr<DSCallback> cb1 = tst_cb;
+  tst_cb->end_ = false;  // don't do the end for now due to a timing issue
+  // config leaf_op, use random_data to avoid I/O
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
+  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
+  schema->AddColumn(col);
+  std::shared_ptr<RandomDataOp> leaf;
+  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf);
+  EXPECT_TRUE(rc.IsOk());
+  // config mapOp
+  std::shared_ptr<MapOp> map_op;
+  auto map_b = MapOp::Builder();
+  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
+  EXPECT_TRUE(rc.IsOk());
+  // config RepeatOp
+  std::shared_ptr<RepeatOp> repeat_op;
+  rc = RepeatOp::Builder(2).Build(&repeat_op);
+  // start build then launch tree
+  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
+  rc = tree->Prepare();
+  EXPECT_TRUE(rc.IsOk());
+  rc = tree->Launch();
+  EXPECT_TRUE(rc.IsOk());
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorMap tensor_map;
+  rc = di.GetNextAsMap(&tensor_map);
+  EXPECT_TRUE(rc.IsOk());
+  while (!tensor_map.empty()) {
+    rc = di.GetNextAsMap(&tensor_map);
+    EXPECT_TRUE(rc.IsOk());
+  }
+
+  std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
+  std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88};
+  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1};
+  // compare only the first `len` entries to make sure no unexpected epoch_end or extra epoch_begin was recorded
+  size_t len = 7;
+  EXPECT_EQ(tst_cb->all_names(len), callback_names);
+  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
+  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
+}
+
+TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
+  // config callback
+  Status rc;
+  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
+  std::shared_ptr<DSCallback> cb1 = tst_cb;
+  tst_cb->end_ = false;  // don't do the end for now due to a timing issue
+  // config leaf_op, use random_data to avoid I/O
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
+  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
+  schema->AddColumn(col);
+  std::shared_ptr<RandomDataOp> leaf;
+  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
+  EXPECT_TRUE(rc.IsOk());
+  // config mapOp
+  std::shared_ptr<MapOp> map_op;
+  auto map_b = MapOp::Builder();
+  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
+  EXPECT_TRUE(rc.IsOk());
+  // config RepeatOp
+  std::shared_ptr<RepeatOp> repeat_op;
+  rc = RepeatOp::Builder(2).Build(&repeat_op);
+  // start build then launch tree
+  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
+  rc = tree->Prepare();
+  EXPECT_TRUE(rc.IsOk());
+  rc = tree->Launch();
+  EXPECT_TRUE(rc.IsOk());
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorMap tensor_map;
+  size_t num_epochs = 2;
+  for (int ep_num = 0; ep_num < num_epochs; ++ep_num) {
+    rc = di.GetNextAsMap(&tensor_map);
+    EXPECT_TRUE(rc.IsOk());
+
+    while (tensor_map.size() != 0) {
+      rc = di.GetNextAsMap(&tensor_map);
+      EXPECT_TRUE(rc.IsOk());
+    }
+  }
+
+  std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND",
+                                             "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"};
+
+  std::vector<int64_t> all_steps = {0, 0, 1, 1, 5, 5, 8, 8, 9, 9, 13, 13, 16};
+  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2};
+
+  size_t len = 13;
+  EXPECT_EQ(tst_cb->all_names(len), callback_names);
+  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
+  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
+}
+
+TEST_F(MindDataTestCallback, TestSelectedCallback) {
+  // config callback
+  Status rc;
+  std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
+  std::shared_ptr<DSCallback> cb1 = tst_cb;
+  tst_cb->end_ = false;
+  // turn off the epoch callbacks
+  tst_cb->epoch_begin_ = false;
+  tst_cb->epoch_end_ = false;
+
+  // config leaf_op, use random_data to avoid I/O
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape shape({});  // empty shape is a 1-value scalar Tensor
+  ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape);
+  schema->AddColumn(col);
+  std::shared_ptr<RandomDataOp> leaf;
+  rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf);
+  EXPECT_TRUE(rc.IsOk());
+  // config mapOp
+  std::shared_ptr<MapOp> map_op;
+  auto map_b = MapOp::Builder();
+  rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op);
+  EXPECT_TRUE(rc.IsOk());
+  // config RepeatOp
+  std::shared_ptr<RepeatOp> repeat_op;
+  rc = RepeatOp::Builder(2).Build(&repeat_op);
+  // start build then launch tree
+  std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op});
+  rc = tree->Prepare();
+  EXPECT_TRUE(rc.IsOk());
+  rc = tree->Launch();
+  EXPECT_TRUE(rc.IsOk());
+  // Start the loop of reading tensors from our pipeline
+  DatasetIterator di(tree);
+  TensorMap tensor_map;
+  size_t num_epochs = 2;
+  for (int ep_num = 0; ep_num < num_epochs; ++ep_num) {
+    rc = di.GetNextAsMap(&tensor_map);
+    EXPECT_TRUE(rc.IsOk());
+
+    while (tensor_map.size() != 0) {
+      rc = di.GetNextAsMap(&tensor_map);
+      EXPECT_TRUE(rc.IsOk());
+    }
+  }
+
+  std::vector<std::string> callback_names = {"BGN", "SPBGN", "SPEND", "SPBGN", "SPEND",
+                                             "SPBGN", "SPEND", "SPBGN", "SPEND"};
+
+  std::vector<int64_t> all_steps = {0, 1, 1, 5, 5, 9, 9, 13, 13};
+  std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 2, 2, 2, 2};
+
+  size_t len = 9;
+  EXPECT_EQ(tst_cb->all_names(len), callback_names);
+  EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs);
+  EXPECT_EQ(tst_cb->all_step_nums(len), all_steps);
+}
@@ -0,0 +1,365 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from builtins import range, super
+import time
+
+import pytest
+
+from mindspore import context
+from mindspore import log as logger
+from mindspore.dataset.callback import DSCallback, WaitedDSCallback
+from mindspore.train import Model
+from mindspore.train.callback import Callback
+
+import mindspore.dataset as ds
+import mindspore.nn as nn
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class MyDSCallback(DSCallback):
+    def __init__(self, step_size=1, events=None, cb_id=0):
+        super().__init__(step_size)
+        self.events = events
+        self.cb_id = cb_id
+
+    def append(self, event_name, ds_run_context):
+        event = [event_name, ds_run_context.cur_epoch_num,
+                 ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num]
+        event = '_'.join([str(e) for e in event])
+        index = -1
+        for i, e in enumerate(self.events):
+            if e[0] == event:
+                index = i
+                break
+        if index != -1:
+            self.events[index][1].append(self.cb_id)
+        else:
+            self.events.append((event, [self.cb_id]))
+
+    def ds_begin(self, ds_run_context):
+        self.append("begin", ds_run_context)
+
+    def ds_end(self, ds_run_context):
+        self.append("end", ds_run_context)
+
+    def ds_epoch_begin(self, ds_run_context):
+        self.append("epoch_begin", ds_run_context)
+
+    def ds_epoch_end(self, ds_run_context):
+        self.append("epoch_end", ds_run_context)
+
+    def ds_step_begin(self, ds_run_context):
+        self.append("step_begin", ds_run_context)
+
+    def ds_step_end(self, ds_run_context):
+        self.append("step_end", ds_run_context)
+
+
+def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
+    events = []
+    cb_id = list(range(map_num))
+
+    def append(name, e, s):
+        event = [name, e + 1, s + 1, e * step_num * repeat + s + 1]
+        event = '_'.join([str(ev) for ev in event])
+        events.append((event, cb_id))
+
+    events.append(("begin_0_0_0", cb_id))
+    for e in range(epoch_num):
+        append("epoch_begin", e, -1)
+        for s in range(step_num * repeat):
+            if s % step_size == 0:
+                append("step_begin", e, s)
+                append("step_end", e, s)
+        append("epoch_end", e, step_num * repeat - 1)
+    return events
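
As a worked example of the helper above, generate_expected(epoch_num=1, step_num=2) evaluates to the event list below (one callback, cb_id 0; the begin event always carries epoch 0, step 0):

# [('begin_0_0_0', [0]),
#  ('epoch_begin_1_0_0', [0]),
#  ('step_begin_1_1_1', [0]), ('step_end_1_1_1', [0]),
#  ('step_begin_1_2_2', [0]), ('step_end_1_2_2', [0]),
#  ('epoch_end_1_2_2', [0])]
print(generate_expected(1, 2))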
+
+
+def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
+    events = []
+
+    arr = list(range(1, steps + 1))
+    data = ds.NumpySlicesDataset(arr, shuffle=False)
+
+    my_cb = MyDSCallback(step_size=step_size, events=events)
+
+    data = data.map(operations=(lambda x: x), callbacks=my_cb)
+    if repeat != 1:
+        data = data.repeat(repeat)
+    itr = data.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        for _ in itr:
+            pass
+
+    expected_events = generate_expected(epochs, steps, step_size, 1, repeat)
+    assert expected_events == events
+
+
+def build_test_case_2cbs(epochs, steps):
+    events1 = []
+    events2 = []
+    my_cb1 = MyDSCallback(events=events1)
+    my_cb2 = MyDSCallback(events=events2)
+
+    arr = list(range(1, steps + 1))
+    data = ds.NumpySlicesDataset(arr, shuffle=False)
+
+    data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2])
+
+    itr = data.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        for _ in itr:
+            pass
+
+    expected_events = generate_expected(epochs, steps)
+    assert expected_events == events1
+    assert expected_events == events2
+
+
+def build_test_case_2maps(epochs, steps):
+    events = []
+    my_cb1 = MyDSCallback(events=events, cb_id=0)
+    my_cb2 = MyDSCallback(events=events, cb_id=1)
+
+    arr = list(range(1, steps + 1))
+    data = ds.NumpySlicesDataset(arr, shuffle=False)
+
+    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
+    data = data.map(operations=(lambda x: x), callbacks=my_cb2)
+
+    itr = data.create_tuple_iterator(num_epochs=epochs)
+    for _ in range(epochs):
+        for _ in itr:
+            pass
+
+    expected_events = generate_expected(epochs, steps, map_num=2)
+
+    assert expected_events[1:] == events[1:]
+
+    for event in events:
+        assert len(event) == 2
+        event, cb_ids = event
+        if event != "begin_0_0_0":
+            assert cb_ids[0] == 0
+            assert cb_ids[1] == 1
+
+
+def test_callbacks_all_methods():
+    logger.info("test_callbacks_all_methods")
+
+    build_test_case_1cb(1, 1)
+    build_test_case_1cb(1, 2)
+    build_test_case_1cb(1, 3)
+    build_test_case_1cb(1, 4)
+
+    build_test_case_1cb(2, 1)
+    build_test_case_1cb(2, 2)
+    build_test_case_1cb(2, 3)
+    build_test_case_1cb(2, 4)
+
+    build_test_case_1cb(3, 1)
+    build_test_case_1cb(3, 2)
+    build_test_case_1cb(3, 3)
+    build_test_case_1cb(3, 4)
+
+
+def test_callbacks_var_step_size():
+    logger.info("test_callbacks_var_step_size")
+
+    build_test_case_1cb(1, 2, 2)
+    build_test_case_1cb(1, 3, 2)
+    build_test_case_1cb(1, 4, 2)
+
+    build_test_case_1cb(2, 2, 2)
+    build_test_case_1cb(2, 3, 2)
+    build_test_case_1cb(2, 4, 2)
+
+    build_test_case_1cb(3, 2, 2)
+    build_test_case_1cb(3, 3, 2)
+    build_test_case_1cb(3, 4, 2)
+
+
+def test_callbacks_all_2cbs():
+    logger.info("test_callbacks_all_2cbs")
+
+    build_test_case_2cbs(4, 1)
+    build_test_case_2cbs(4, 2)
+    build_test_case_2cbs(4, 3)
+    build_test_case_2cbs(4, 4)
+
+
+def test_callbacks_2maps():
+    logger.info("test_callbacks_2maps")
+
+    build_test_case_2maps(5, 10)
+
+    build_test_case_2maps(6, 9)
+
+
+class MyWaitedCallback(WaitedDSCallback):
+    def __init__(self, events, step_size=1):
+        super().__init__(step_size)
+        self.events = events
+
+    def sync_epoch_begin(self, train_run_context, ds_run_context):
+        event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
+        self.events.append(event)
+
+    def sync_step_begin(self, train_run_context, ds_run_context):
+        event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
+        self.events.append(event)
+
+
+class MyMSCallback(Callback):
+    def __init__(self, events):
+        self.events = events
+
+    def epoch_end(self, run_context):
+        cb_params = run_context.original_args()
+        event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
+        self.events.append(event)
+
+    def step_end(self, run_context):
+        cb_params = run_context.original_args()
+        event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
+        self.events.append(event)
+
+
+class Net(nn.Cell):
+    def construct(self, x, y):
+        return x
+
+
+def test_train_non_sink():
+    logger.info("test_train_non_sink")
+
+    events = []
+    my_cb1 = MyWaitedCallback(events, 1)
+    my_cb2 = MyMSCallback(events)
+    arr = [1, 2, 3, 4]
+    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
+    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
+
+    net = Net()
+    model = Model(net)
+
+    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
+    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
+                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
+                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
+                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
+                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
+                              'ms_step_end_2_8', 'ms_epoch_end_2_8']
+
+    assert events == expected_synced_events
+
+
+def test_train_batch_size2():
+    logger.info("test_train_batch_size2")
+
+    events = []
+    my_cb1 = MyWaitedCallback(events, 2)
+    my_cb2 = MyMSCallback(events)
+    arr = [1, 2, 3, 4]
+    data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
+    data = data.map(operations=(lambda x: x), callbacks=my_cb1)
+    data = data.batch(2)
+    net = Net()
+    model = Model(net)
+
+    model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
+
+    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3',
+                              'ms_step_end_1_2',
+                              'ms_epoch_end_1_2', 'ds_epoch_begin_2_4',
+                              'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7',
+                              'ms_step_end_2_4', 'ms_epoch_end_2_4']
+
+    assert events == expected_synced_events
+
+
+def test_callbacks_validations():
+    logger.info("test_callbacks_validations")
+
+    with pytest.raises(Exception) as err:
+        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
+        data.map(operations=(lambda x: x), callbacks=0)
+    assert "Argument callbacks with value 0 is not " in str(err.value)
+
+    with pytest.raises(Exception) as err:
+        my_cb1 = MyDSCallback()
+        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
+        data.map(operations=(lambda x: x), callbacks=[my_cb1, 0])
+    assert "Argument callbacks[1] with value 0 is not " in str(err.value)
+
+    with pytest.raises(Exception) as err:
+        class BadCB(DSCallback):
+            pass
+
+        my_cb = BadCB()
+
+        data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
+        data = data.map(operations=(lambda x: x), callbacks=my_cb)
+        for _ in data:
+            pass
+    assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)
+
+
+def test_callback_sink_simulation():
+    logger.info("test_callback_sink_simulation")
+
+    events = []
+    epochs = 2
+    my_cb = MyWaitedCallback(events, 1)
+    data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
+    data = data.map(operations=(lambda x: x), callbacks=my_cb)
+    data = data.to_device()
+    data.send(num_epochs=epochs)
+    for e in range(epochs):
+        for s in range(4):
+            time.sleep(0.5)
+            events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}")
+            my_cb.step_end(run_context=0)
+        events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}")
+        my_cb.epoch_end(run_context=0)
+    expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3',
+                              'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4',
+                              'ms_epoch_end_1_4', 'ds_epoch_begin_2_4',
+                              'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6',
+                              'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8',
+                              'ms_step_end_2_8', 'ms_epoch_end_2_8']
+
+    assert events == expected_synced_events
+
+
+def test_callbacks_repeat():
+    logger.info("test_callbacks_repeat")
+
+    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
+    build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3)
+    build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
+    build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)
+
+
+if __name__ == '__main__':
+    test_callbacks_all_methods()
+    test_callbacks_all_2cbs()
+    test_callbacks_2maps()
+    test_callbacks_validations()
+    test_callbacks_var_step_size()
+    test_train_batch_size2()
+    test_callback_sink_simulation()
+    test_callbacks_repeat()