forked from mindspore-Ecosystem/mindspore
fix: dataset reused core when train stop first
This commit is contained in:
parent
629be861a2
commit
7b7b083224
|
@ -89,7 +89,7 @@ DeviceQueueOp::~DeviceQueueOp() {
|
|||
|
||||
#ifdef ENABLE_GPUQUE
|
||||
void DeviceQueueOp::ReleaseData(void *addr, int32_t worker_id) {
|
||||
if (addr != nullptr) {
|
||||
if (addr != nullptr && worker_id >= 0 && worker_id < pool_.size()) {
|
||||
pool_[worker_id]->Deallocate(addr);
|
||||
}
|
||||
}
|
||||
|
@ -277,15 +277,10 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
#endif
|
||||
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&curr_row));
|
||||
}
|
||||
if (curr_row.eoe() && send_epoch_end_) {
|
||||
TensorRow dummy_row;
|
||||
auto status =
|
||||
tdtInstancePtr->hostPush(dummy_row, is_profiling_enable, &tdt_cost, ACL_TENSOR_DATA_END_OF_SEQUENCE);
|
||||
|
||||
RETURN_IF_NOT_OK(CheckPushStatus(status, stop_send_, &send_finished_, &is_break_loop));
|
||||
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
|
||||
stop_send_ = true;
|
||||
}
|
||||
// send epoch end flag: ACL_TENSOR_DATA_END_OF_SEQUENCE to tdt
|
||||
RETURN_IF_NOT_OK(SendEpochEndToAscend(curr_row, is_profiling_enable, &tdt_cost, &is_break_loop));
|
||||
|
||||
#ifndef ENABLE_SECURITY
|
||||
if (is_profiling_enable) {
|
||||
connector_size = ChildOpConnectorSize();
|
||||
|
@ -307,6 +302,21 @@ Status DeviceQueueOp::SendDataToAscend() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeviceQueueOp::SendEpochEndToAscend(const TensorRow &curr_row, const bool &is_profiling_enable,
|
||||
int32_t *tdt_cost, bool *is_break_loop) {
|
||||
RETURN_UNEXPECTED_IF_NULL(tdt_cost);
|
||||
RETURN_UNEXPECTED_IF_NULL(is_break_loop);
|
||||
if (curr_row.eoe() && send_epoch_end_ && tdtInstancePtr->acl_handle_ != nullptr) {
|
||||
TensorRow dummy_row;
|
||||
auto status = tdtInstancePtr->hostPush(dummy_row, is_profiling_enable, tdt_cost, ACL_TENSOR_DATA_END_OF_SEQUENCE);
|
||||
|
||||
RETURN_IF_NOT_OK(CheckPushStatus(status, stop_send_, &send_finished_, is_break_loop));
|
||||
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
|
||||
stop_send_ = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DeviceQueueOp::WaitContinueSignal() const {
|
||||
while (stop_send_ && ascend_keep_waiting_) {
|
||||
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
|
||||
|
|
|
@ -139,6 +139,8 @@ class DeviceQueueOp : public PipelineOp {
|
|||
#ifdef ENABLE_TDTQUE
|
||||
void WaitContinueSignal() const;
|
||||
Status SendDataToAscend();
|
||||
Status SendEpochEndToAscend(const TensorRow &curr_row, const bool &is_profiling_enable, int32_t *tdt_cost,
|
||||
bool *is_break_loop);
|
||||
void LimitSendingBatches(int64_t send_batch, int64_t *sending_num, std::shared_ptr<ConfigManager> cfg);
|
||||
Status SendRowToTdt(TensorRow curr_row, bool is_profiling_enable, int32_t *tdt_cost);
|
||||
// check status that push data into device
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "minddata/dataset/engine/perf/profiling.h"
|
||||
#endif
|
||||
#include "minddata/dataset/util/log_adapter.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
#if ENABLE_D
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#endif
|
||||
|
@ -84,6 +85,14 @@ Status TdtPlugin::hostPush(TensorRow ts_row, bool profiling, int32_t *time, aclt
|
|||
auto status = acltdtSendTensor(acl_handle_, acl_dataset, -1);
|
||||
DestroyAclDataset(acl_dataset);
|
||||
if (status != ACL_SUCCESS) {
|
||||
// if the device_queue thread had been interrupted by master, just print warning and return success
|
||||
if (mindspore::dataset::this_thread::is_interrupted()) {
|
||||
MS_LOG(WARNING) << "Device queue thread had been interrupted by TdtHandle::DestroyHandle, you can ignore "
|
||||
<< "the above error: 'failed to send...'. In this scenario, the training ends first without "
|
||||
<< "using all epoch(s) data, and the data preprocessing is blocked by the data "
|
||||
<< "transmission channel on the device side. So we force the data transmission channel to stop.";
|
||||
return Status::OK();
|
||||
}
|
||||
ReportErrorMessage();
|
||||
RETURN_STATUS_UNEXPECTED("Tdt Send data failed.");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue