forked from mindspore-Ecosystem/mindspore
!4776 Introduce 2 extra ctrl flags to DataBuffer in dataset, address remaining cmts to PR4632
Merge pull request !4776 from ZiruiWu/map_callback_follow_up
This commit is contained in:
commit
9d7250c483
|
@ -45,6 +45,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
|
|||
.def("get_op_connector_size", &ConfigManager::op_connector_size)
|
||||
.def("get_seed", &ConfigManager::seed)
|
||||
.def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
|
||||
.def("get_callback_timeout", &ConfigManager::callback_timeout)
|
||||
.def("set_callback_timeout", &ConfigManager::set_callback_timeout)
|
||||
.def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });
|
||||
}));
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) {
|
|||
// return Status::OK() if no begin is needed
|
||||
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||
|
||||
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||
RETURN_IF_NOT_OK(op_->WaitForWorkers());
|
||||
|
||||
// Now do the actual callback
|
||||
for (size_t ind : callback_inds) {
|
||||
|
@ -69,7 +69,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
|
|||
// return Status::OK() if no epoch_begin is needed
|
||||
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||
|
||||
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||
RETURN_IF_NOT_OK(op_->WaitForWorkers());
|
||||
|
||||
// Now do the actual callback
|
||||
for (size_t ind : callback_inds) {
|
||||
|
@ -89,7 +89,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
|
|||
// return Status::OK() if no step_begin is needed
|
||||
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||
|
||||
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||
RETURN_IF_NOT_OK(op_->WaitForWorkers());
|
||||
|
||||
// Now do the actual callback
|
||||
for (size_t ind : callback_inds) {
|
||||
|
@ -108,7 +108,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) {
|
|||
// return Status::OK() if no end is needed
|
||||
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||
|
||||
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||
RETURN_IF_NOT_OK(op_->WaitForWorkers());
|
||||
|
||||
// Now do the actual callback
|
||||
for (size_t ind : callback_inds) {
|
||||
|
@ -127,7 +127,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
|
|||
// return Status::OK() if no epoch_end is needed
|
||||
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||
|
||||
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||
RETURN_IF_NOT_OK(op_->WaitForWorkers());
|
||||
|
||||
// Now do the actual callback
|
||||
for (size_t ind : callback_inds) {
|
||||
|
@ -147,7 +147,7 @@ Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
|
|||
// return Status::OK() if no step_end is needed
|
||||
RETURN_OK_IF_TRUE(callback_inds.empty());
|
||||
|
||||
RETURN_IF_NOT_OK(op_->PauseFromMaster());
|
||||
RETURN_IF_NOT_OK(op_->WaitForWorkers());
|
||||
|
||||
// Now do the actual callback
|
||||
for (size_t ind : callback_inds) {
|
||||
|
|
|
@ -32,7 +32,7 @@ class DatasetOp;
|
|||
/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
|
||||
class CallbackManager {
|
||||
public:
|
||||
/// CallbackManager default constructor. Init needs to be called before using the created instance.
|
||||
/// \brief CallbackManager default constructor. Init needs to be called before using the created instance.
|
||||
CallbackManager() : enabled_(false) {}
|
||||
|
||||
/// \brief
|
||||
|
|
|
@ -88,5 +88,8 @@ uint32_t ConfigManager::seed() const { return seed_; }
|
|||
void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
|
||||
|
||||
void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
|
||||
|
||||
void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; }
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -116,9 +116,17 @@ class ConfigManager {
|
|||
void set_monitor_sampling_interval(uint32_t interval);
|
||||
|
||||
// getter function
|
||||
// @return The iterval of monitor sampling
|
||||
// @return The interval of monitor sampling
|
||||
int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
|
||||
|
||||
// setter function
|
||||
// @param timeout - The setting to apply to the config
|
||||
void set_callback_timeout(uint32_t timeout);
|
||||
|
||||
// getter function
|
||||
// @return The timeout DSWaitedCallback would wait for before raising an error
|
||||
int32_t callback_timeout() const { return callback_timout_; }
|
||||
|
||||
private:
|
||||
int32_t rows_per_buffer_{kCfgRowsPerBuffer};
|
||||
int32_t num_parallel_workers_{kCfgParallelWorkers};
|
||||
|
@ -126,8 +134,9 @@ class ConfigManager {
|
|||
int32_t op_connector_size_{kCfgOpConnectorSize};
|
||||
uint32_t seed_{kCfgDefaultSeed};
|
||||
uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
|
||||
uint32_t callback_timout_{kCfgCallbackTimeout};
|
||||
|
||||
// Private helper function that taks a nlohmann json format and populates the settings
|
||||
// Private helper function that takes a nlohmann json format and populates the settings
|
||||
// @param j - The json nlohmann json info
|
||||
Status FromJson(const nlohmann::json &j);
|
||||
};
|
||||
|
|
|
@ -68,6 +68,7 @@ constexpr uint32_t kCfgWorkerConnectorSize = 16;
|
|||
constexpr uint32_t kCfgOpConnectorSize = 16;
|
||||
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
|
||||
constexpr uint32_t kCfgMonitorSamplingInterval = 10;
|
||||
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
|
||||
|
||||
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
|
||||
constexpr uint8_t kCVInvalidType = 255;
|
||||
|
|
|
@ -38,7 +38,9 @@ class DataBuffer {
|
|||
enum BufferFlags : uint32_t {
|
||||
kDeBFlagNone = 0,
|
||||
kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg
|
||||
kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg
|
||||
kDeBFlagEOE = 1u << 1, // The buffer is an eoe end-of-epoch msg
|
||||
kDeBFlagWait = 1u << 2, // The buffer is an control signal for workers to suspend operations
|
||||
kDeBFlagQuit = 1u << 3 // The buffer is a control signal for workers to quit
|
||||
};
|
||||
|
||||
// Name: Constructor #1
|
||||
|
@ -64,6 +66,10 @@ class DataBuffer {
|
|||
|
||||
bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); }
|
||||
|
||||
bool wait() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagWait)); }
|
||||
|
||||
bool quit() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagQuit)); }
|
||||
|
||||
// Simple getter funcs
|
||||
int32_t id() const { return buffer_id_; }
|
||||
|
||||
|
|
|
@ -363,10 +363,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// This function is only intended to be called by CallbackManager within the master thread of ParallelOp
|
||||
/// The expected behavior is this, when this function is invoked, this function will block until all the workers
|
||||
/// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master.
|
||||
/// They would automatically wait on the QueueList when they are done. Hence, for now, a Unpause() function is not
|
||||
/// needed. Only parallelOp needs to override this function.
|
||||
/// They would automatically wait on the QueueList when they are done.
|
||||
/// \return Status
|
||||
virtual Status PauseFromMaster() { return Status::OK(); }
|
||||
virtual Status WaitForWorkers() { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
/// \brief Removes a parent operator from this operator
|
||||
|
|
|
@ -166,7 +166,7 @@ Status MapOp::operator()() {
|
|||
// init callback
|
||||
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
|
||||
Status rc = local_queues_.Register(tree_->AllTasks());
|
||||
RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
|
||||
if (rc.IsError()) {
|
||||
TaskManager::FindMe()->Post();
|
||||
return rc;
|
||||
|
@ -205,23 +205,29 @@ Status MapOp::operator()() {
|
|||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
|
||||
}
|
||||
|
||||
// send the eoe buffer to worker
|
||||
|
||||
// reset epoch_step when a new epoch is about to start
|
||||
// check whether this is the end of a real epoch (not all eoe signals end of epoch)
|
||||
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
|
||||
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
|
||||
ep_step = 0;
|
||||
}
|
||||
// Propagate the eoe buffer to worker
|
||||
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
|
||||
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
|
||||
UpdateRepeatAndEpochCounter();
|
||||
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
|
||||
}
|
||||
// the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback
|
||||
// End() is commented out because it might never be called due to the lack of EOF when EpochCtrl is -1
|
||||
// RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
|
||||
// handle eof logic
|
||||
// Handle eof logic, this code might never be reached if epoch_ctrl = -1.
|
||||
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
|
||||
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
|
||||
|
||||
// Quit all workers, this code might never be reached if EpochCtrl is -1.
|
||||
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
|
||||
auto quit = std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagQuit));
|
||||
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(quit)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -242,26 +248,27 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
|
|||
// Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
|
||||
// rather than use the base-class defaults.
|
||||
while (true) {
|
||||
// handle the pause logic. Pause is triggered when an buffer id of -1 with no special flag and no row is received
|
||||
if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
|
||||
// when worker receives the signal from master thread, it increments a atomic int
|
||||
// the last guy who increments the counter, wakes up master thread
|
||||
if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
|
||||
// this will block the worker until master thread gives it a new work
|
||||
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
|
||||
continue;
|
||||
// Handle special logic where buffer carries a ctrl flag.
|
||||
if (in_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) {
|
||||
if (in_buffer->wait()) {
|
||||
// When worker receives the signal from master thread, it increments a atomic int
|
||||
// The last guy who increments the counter, wakes up master thread
|
||||
if (++num_workers_paused_ == num_workers_) {
|
||||
wait_for_workers_post_.Set();
|
||||
}
|
||||
// This will block the worker until master thread gives it a new work
|
||||
} else if (in_buffer->eoe()) {
|
||||
// Calling base class EoeReceived to forward eoe buffer.
|
||||
RETURN_IF_NOT_OK(EoeReceived(worker_id));
|
||||
// Fetch next data buffer and map job list
|
||||
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
|
||||
continue;
|
||||
} else if (in_buffer->eof()) {
|
||||
// Calling base class EofReceived to forward eof buffer.
|
||||
RETURN_IF_NOT_OK(EofReceived(worker_id));
|
||||
} else if (in_buffer->quit()) {
|
||||
break;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
|
||||
continue;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
|
||||
std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
|
||||
// Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
|
||||
|
@ -299,9 +306,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
|
|||
|
||||
// Variable to keep the result after executing the job.
|
||||
std::vector<TensorRow> result_table;
|
||||
// Executing the list of jobs
|
||||
// Executing the list of jobs.
|
||||
for (size_t i = 0; i < job_list.size(); i++) {
|
||||
// Execute MapJob.
|
||||
// Execute MapWorkerJob.
|
||||
RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
|
||||
// Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
|
||||
if (i + 1 < job_list.size()) {
|
||||
|
@ -311,8 +318,7 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
|
|||
|
||||
// Sanity check a row in result_table
|
||||
if (!result_table.empty() && out_columns_.size() != result_table[0].size()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Result of a tensorOp doesn't match output column names");
|
||||
RETURN_STATUS_UNEXPECTED("Result of a tensorOp doesn't match output column names");
|
||||
}
|
||||
|
||||
// Merging the data processed by job (result_table) with the data that are not used.
|
||||
|
@ -386,7 +392,7 @@ Status MapOp::InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_
|
|||
// columns from child are correct
|
||||
RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map));
|
||||
|
||||
// initialize keep_input_columns, true means to keep the column.
|
||||
// Initialize keep_input_columns, true means to keep the column.
|
||||
keep_input_columns_.resize(col_name_id_map->size(), true);
|
||||
for (const auto &col_name : in_columns_) {
|
||||
int32_t missed = (*col_name_id_map)[col_name];
|
||||
|
@ -449,18 +455,18 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
|
|||
return p->RunOnNode(shared_from_base<MapOp>(), modified);
|
||||
}
|
||||
|
||||
Status MapOp::PauseFromMaster() {
|
||||
Status MapOp::WaitForWorkers() {
|
||||
// reset num_paused workers to 0
|
||||
num_workers_paused_ = 0;
|
||||
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
|
||||
// a special buffer (id=-1, empty, none flag) is used to signal that worker needs to pause.
|
||||
RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
|
||||
std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
|
||||
std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagWait))));
|
||||
}
|
||||
// wait until all workers are done processing their work in local_queue_
|
||||
RETURN_IF_NOT_OK(master_pause_wp_.Wait());
|
||||
RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
|
||||
// clear the WaitPost for the next Wait()
|
||||
master_pause_wp_.Clear();
|
||||
wait_for_workers_post_.Clear();
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -228,10 +228,10 @@ class MapOp : public ParallelOp {
|
|||
// Indices of the columns to process.
|
||||
std::vector<size_t> to_process_indices_;
|
||||
|
||||
// wait post used to perform the pausing logic in MapOp
|
||||
WaitPost master_pause_wp_;
|
||||
// Wait post used to perform the pausing logic in MapOp
|
||||
WaitPost wait_for_workers_post_;
|
||||
|
||||
// count number of workers that have signaled master
|
||||
// Count number of workers that have signaled master
|
||||
std::atomic_int num_workers_paused_;
|
||||
|
||||
// Private function for worker/thread to loop continuously. It comprises the main
|
||||
|
@ -272,7 +272,7 @@ class MapOp : public ParallelOp {
|
|||
// Workers upon receiving the suspension token from master thread, increment an atomic count, the last worker
|
||||
// who does the increment wakes up the master.
|
||||
// @return - Status
|
||||
Status PauseFromMaster() override;
|
||||
Status WaitForWorkers() override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,7 +34,7 @@ class Semaphore {
|
|||
/// \brief Decrement the internal counter. Will be blocked if the value is 0.
|
||||
/// \return Error code. Can get interrupt.
|
||||
Status P();
|
||||
/// \brief Increment the internal counter. Wakeup on of the waiters if any.
|
||||
/// \brief Increment the internal counter. Wake up on of the waiters if any.
|
||||
void V();
|
||||
/// \brief Peek the internal value
|
||||
/// \return The internal value
|
||||
|
|
|
@ -18,6 +18,7 @@ Python callback class
|
|||
import threading
|
||||
from mindspore._c_dataengine import PyDSCallback
|
||||
from mindspore.train.callback import Callback
|
||||
import mindspore.dataset as ds
|
||||
from .validators import check_callback
|
||||
|
||||
|
||||
|
@ -170,7 +171,6 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
"""
|
||||
self.epoch_run_context = run_context
|
||||
self.epoch_event.set()
|
||||
self.epoch_event.clear()
|
||||
|
||||
def ds_epoch_begin(self, ds_run_context):
|
||||
"""
|
||||
|
@ -180,10 +180,12 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
ds_run_context: Include some information of the pipeline.
|
||||
"""
|
||||
if ds_run_context.cur_epoch_num > 1:
|
||||
if self.epoch_run_context is None:
|
||||
self.epoch_event.wait()
|
||||
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
|
||||
self.epoch_event.clear()
|
||||
if not success:
|
||||
raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)")
|
||||
# by the time this thread wakes up, self.epoch_run_context is already available
|
||||
self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
|
||||
self.epoch_run_context = None
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
|
@ -194,7 +196,6 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
"""
|
||||
self.step_run_context = run_context
|
||||
self.step_event.set()
|
||||
self.step_event.clear()
|
||||
|
||||
def ds_step_begin(self, ds_run_context):
|
||||
"""
|
||||
|
@ -204,10 +205,12 @@ class WaitedDSCallback(Callback, DSCallback):
|
|||
ds_run_context: Include some information of the pipeline.
|
||||
"""
|
||||
if ds_run_context.cur_step_num > self.step_size:
|
||||
if self.step_run_context is None:
|
||||
self.step_event.wait()
|
||||
success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
|
||||
self.step_event.clear()
|
||||
if not success:
|
||||
raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)")
|
||||
# by the time this thread wakes up, self.epoch_run_context is already available
|
||||
self.sync_step_begin(self.step_run_context, ds_run_context)
|
||||
self.step_run_context = None
|
||||
|
||||
def create_runtime_obj(self):
|
||||
"""
|
||||
|
|
|
@ -157,6 +157,38 @@ def get_monitor_sampling_interval():
|
|||
return _config.get_monitor_sampling_interval()
|
||||
|
||||
|
||||
def set_callback_timeout(timeout):
|
||||
"""
|
||||
Set the default timeout (in seconds) for DSWaitedCallback.
|
||||
In case of a deadlock, the wait function will exit after the timeout period.
|
||||
|
||||
Args:
|
||||
timeout (int): timeout(s) to be used to end teh wait in DSWaitedCallback in case of a deadlock.
|
||||
|
||||
Raises:
|
||||
ValueError: If timeout is invalid (<= 0 or > MAX_INT_32).
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> # sets the new timout value.
|
||||
>>> ds.config.set_callback_timeout(100)
|
||||
"""
|
||||
if timeout <= 0 or timeout > INT32_MAX:
|
||||
raise ValueError("timeout given is not within the required range.")
|
||||
_config.set_callback_timeout(timeout)
|
||||
|
||||
|
||||
def get_callback_timeout():
|
||||
"""
|
||||
Get the default timeout for DSWaitedCallback.
|
||||
In case of a deadlock, the wait function will exit after the timeout period.
|
||||
|
||||
Returns:
|
||||
Int, the duration in seconds
|
||||
"""
|
||||
return _config.get_callback_timeout()
|
||||
|
||||
|
||||
def __str__():
|
||||
"""
|
||||
String representation of the configurations.
|
||||
|
|
|
@ -57,7 +57,7 @@ class TestCallback : public DSCallback {
|
|||
begin_(true),
|
||||
epoch_begin_(true),
|
||||
step_begin_(true),
|
||||
end_(true),
|
||||
end_(false),
|
||||
epoch_end_(true),
|
||||
step_end_(true) {
|
||||
all_names_.reserve(32);
|
||||
|
@ -145,7 +145,6 @@ TEST_F(MindDataTestCallback, TestBasicCallback) {
|
|||
Status rc;
|
||||
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64);
|
||||
std::shared_ptr<DSCallback> cb1 = tst_cb;
|
||||
tst_cb->end_ = false; // don't do the end for now due to a timing issue
|
||||
// config leaf_op, use random_data to avoid I/O
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
|
||||
|
@ -193,7 +192,6 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) {
|
|||
Status rc;
|
||||
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
|
||||
std::shared_ptr<DSCallback> cb1 = tst_cb;
|
||||
tst_cb->end_ = false; // don't do the end for now due to a timing issue
|
||||
// config leaf_op, use random_data to avoid I/O
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape shape({}); // empty shape is a 1-value scalar Tensor
|
||||
|
@ -247,7 +245,6 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
|
|||
Status rc;
|
||||
std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4);
|
||||
std::shared_ptr<DSCallback> cb1 = tst_cb;
|
||||
tst_cb->end_ = false;
|
||||
// turn off the epochs
|
||||
tst_cb->epoch_begin_ = false;
|
||||
tst_cb->epoch_end_ = false;
|
||||
|
|
|
@ -29,7 +29,7 @@ import mindspore.nn as nn
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
||||
class MyDSCallback(DSCallback):
|
||||
class BaseCallback(DSCallback):
|
||||
def __init__(self, step_size=1, events=None, cb_id=0):
|
||||
super().__init__(step_size)
|
||||
self.events = events
|
||||
|
@ -49,25 +49,36 @@ class MyDSCallback(DSCallback):
|
|||
else:
|
||||
self.events.append((event, [self.cb_id]))
|
||||
|
||||
|
||||
class Begin(BaseCallback):
|
||||
def ds_begin(self, ds_run_context):
|
||||
self.append("begin", ds_run_context)
|
||||
|
||||
def ds_end(self, ds_run_context):
|
||||
self.append("end", ds_run_context)
|
||||
|
||||
class EpochBegin(BaseCallback):
|
||||
def ds_epoch_begin(self, ds_run_context):
|
||||
self.append("epoch_begin", ds_run_context)
|
||||
|
||||
|
||||
class EpochEnd(BaseCallback):
|
||||
def ds_epoch_end(self, ds_run_context):
|
||||
self.append("epoch_end", ds_run_context)
|
||||
|
||||
|
||||
class StepBegin(BaseCallback):
|
||||
def ds_step_begin(self, ds_run_context):
|
||||
self.append("step_begin", ds_run_context)
|
||||
|
||||
|
||||
class StepEnd(BaseCallback):
|
||||
def ds_step_end(self, ds_run_context):
|
||||
self.append("step_end", ds_run_context)
|
||||
|
||||
|
||||
class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd):
|
||||
pass
|
||||
|
||||
|
||||
def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1):
|
||||
events = []
|
||||
cb_id = list(range(map_num))
|
||||
|
@ -98,6 +109,11 @@ def build_test_case_1cb(epochs, steps, step_size=1, repeat=1):
|
|||
|
||||
data = data.map(operations=(lambda x: x), callbacks=my_cb)
|
||||
if repeat != 1:
|
||||
if repeat % 2 == 0 and repeat != 2:
|
||||
data = data.repeat(2)
|
||||
data = data.map(operations=(lambda x: x))
|
||||
data = data.repeat(repeat // 2)
|
||||
else:
|
||||
data = data.repeat(repeat)
|
||||
itr = data.create_tuple_iterator(num_epochs=epochs)
|
||||
for _ in range(epochs):
|
||||
|
@ -201,11 +217,10 @@ def test_callbacks_all_2cbs():
|
|||
build_test_case_2cbs(4, 4)
|
||||
|
||||
|
||||
def test_callbacks_2maps():
|
||||
def skip_test_callbacks_2maps():
|
||||
logger.info("test_callbacks_2maps")
|
||||
|
||||
# This test case is skipped because in rare cases (25 out 1000) it might fail
|
||||
build_test_case_2maps(5, 10)
|
||||
|
||||
build_test_case_2maps(6, 9)
|
||||
|
||||
|
||||
|
@ -243,8 +258,8 @@ class Net(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
def test_train_non_sink():
|
||||
logger.info("test_train_non_sink")
|
||||
def test_callbacks_non_sink():
|
||||
logger.info("test_callbacks_non_sink")
|
||||
|
||||
events = []
|
||||
my_cb1 = MyWaitedCallback(events, 1)
|
||||
|
@ -267,8 +282,8 @@ def test_train_non_sink():
|
|||
assert events == expected_synced_events
|
||||
|
||||
|
||||
def test_train_batch_size2():
|
||||
logger.info("test_train_batch_size2")
|
||||
def test_callbacks_non_sink_batch_size2():
|
||||
logger.info("test_callbacks_non_sink_batch_size2")
|
||||
|
||||
events = []
|
||||
my_cb1 = MyWaitedCallback(events, 2)
|
||||
|
@ -291,6 +306,27 @@ def test_train_batch_size2():
|
|||
assert events == expected_synced_events
|
||||
|
||||
|
||||
def test_callbacks_non_sink_mismatch_size():
|
||||
logger.info("test_callbacks_non_sink_mismatch_size")
|
||||
default_timeout = ds.config.get_callback_timeout()
|
||||
ds.config.set_callback_timeout(1)
|
||||
|
||||
events = []
|
||||
my_cb1 = MyWaitedCallback(events, 2)
|
||||
my_cb2 = MyMSCallback(events)
|
||||
arr = [1, 2, 3, 4]
|
||||
data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
|
||||
data = data.map(operations=(lambda x: x), callbacks=my_cb1)
|
||||
data = data.batch(3)
|
||||
net = Net()
|
||||
model = Model(net)
|
||||
with pytest.raises(Exception) as err:
|
||||
model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
|
||||
assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value)
|
||||
|
||||
ds.config.set_callback_timeout(default_timeout)
|
||||
|
||||
|
||||
def test_callbacks_validations():
|
||||
logger.info("test_callbacks_validations")
|
||||
|
||||
|
@ -318,7 +354,7 @@ def test_callbacks_validations():
|
|||
assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value)
|
||||
|
||||
|
||||
def test_callback_sink_simulation():
|
||||
def test_callbacks_sink_simulation():
|
||||
logger.info("test_callback_sink_simulation")
|
||||
|
||||
events = []
|
||||
|
@ -353,13 +389,72 @@ def test_callbacks_repeat():
|
|||
build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3)
|
||||
build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3)
|
||||
|
||||
build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2)
|
||||
build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4)
|
||||
build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8)
|
||||
build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16)
|
||||
|
||||
|
||||
def test_callbacks_exceptions():
|
||||
logger.info("test_callbacks_exceptions")
|
||||
|
||||
class BadCB(DSCallback):
|
||||
def ds_begin(self, ds_run_context):
|
||||
raise RuntimeError("Bad begin")
|
||||
|
||||
with pytest.raises(Exception) as err:
|
||||
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
|
||||
data = data.map(operations=(lambda x: x), callbacks=BadCB())
|
||||
for _ in data:
|
||||
pass
|
||||
assert "RuntimeError: Bad begin" in str(err.value)
|
||||
|
||||
|
||||
def test_callbacks_one_cb():
|
||||
logger.info("test_callbacks_one_cb")
|
||||
|
||||
data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False)
|
||||
events1 = []
|
||||
events2 = []
|
||||
events3 = []
|
||||
my_begin = Begin(events=events1, cb_id=1)
|
||||
my_epoch_begin = EpochBegin(events=events2, cb_id=2)
|
||||
my_epoch_end = EpochEnd(events=events3, cb_id=3)
|
||||
my_step_begin = StepBegin(events=events3, cb_id=3)
|
||||
my_step_end = StepEnd(events=events2, cb_id=2)
|
||||
|
||||
data = data.map(operations=(lambda x: x), callbacks=my_begin)
|
||||
data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end])
|
||||
data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin])
|
||||
|
||||
itr = data.create_tuple_iterator()
|
||||
for _ in range(2):
|
||||
for _ in itr:
|
||||
pass
|
||||
expected_events1 = [('begin_0_0_0', [1])]
|
||||
expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]),
|
||||
('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]),
|
||||
('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]),
|
||||
('step_end_2_4_8', [2])]
|
||||
expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]),
|
||||
('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]),
|
||||
('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]),
|
||||
('epoch_end_2_4_8', [3])]
|
||||
assert events1 == expected_events1
|
||||
assert events2 == expected_events2
|
||||
assert events3 == expected_events3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_callbacks_all_methods()
|
||||
skip_test_callbacks_2maps()
|
||||
test_callbacks_all_2cbs()
|
||||
test_callbacks_2maps()
|
||||
test_callbacks_all_methods()
|
||||
test_callbacks_exceptions()
|
||||
test_callbacks_repeat()
|
||||
test_callbacks_sink_simulation()
|
||||
test_callbacks_validations()
|
||||
test_callbacks_var_step_size()
|
||||
test_train_batch_size2()
|
||||
test_callback_sink_simulation()
|
||||
test_callbacks_repeat()
|
||||
test_callbacks_non_sink_batch_size2()
|
||||
test_callbacks_non_sink()
|
||||
test_callbacks_one_cb()
|
||||
test_callbacks_non_sink_mismatch_size()
|
||||
|
|
Loading…
Reference in New Issue