add tensor sync status

parent 521e351dac
commit 5614b2ba6c
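At a glance, the change collapses the two Tensor booleans, dirty_ (is_dirty()/set_dirty()) and need_sync_ (need_sync()/set_need_sync()), into one four-state TensorSyncStatus enum, so each call site now records which direction a pending sync must go rather than just that one is pending. The sketch below is a reviewer's reading of how the old flag idioms map onto the new states; the enum itself is the one this commit adds to tensor.h, but the mapping comments are interpretation, not text from the commit:

    // Reviewer's sketch: the enum added by this commit, annotated with the
    // old boolean-flag idioms each state replaces (annotations are interpretive).
    enum TensorSyncStatus {
      kNoNeedSync,                       // was set_dirty(false) / set_need_sync(false)
      kNeedSyncHostToDevice,             // was set_dirty(true): host copy is newer
      kNeedSyncDeviceToHost,             // new state: device copy newer, sync lazily
      kNeedSyncDeviceToHostImmediately   // was set_need_sync(true): sync before next host read
    };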
@@ -410,7 +410,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
     for (auto &pre_output : pre_output_tensors) {
       tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
       tensor->set_device_address(pre_output->device_address());
-      tensor->set_dirty(false);
+      tensor->set_sync_status(kNoNeedSync);
       outputs->emplace_back(tensor);
     }
   } else {
@@ -38,9 +38,9 @@ void UpdateOutputTensors(VectorRef *outputs,
       auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
       tensor->set_device_address(address);
     }
-    if (tensor->need_sync()) {
+    if (tensor->NeedSyncDeviceToHostImmediately()) {
       tensor->data_sync();
-      tensor->set_need_sync(false);
+      tensor->set_sync_status(kNoNeedSync);
     }
   }
 }
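This hunk is the read side of the new protocol: data_sync() runs only for tensors explicitly marked for an immediate device-to-host transfer, and the status is cleared afterwards. A minimal self-contained sketch of the pattern (MiniTensor and its data_sync() are illustrative stand-ins, not MindSpore types):

    #include <cstdio>

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Illustrative stand-in for tensor::Tensor; only the sync-status surface is modeled.
    struct MiniTensor {
      TensorSyncStatus sync_status{kNeedSyncHostToDevice};
      bool NeedSyncDeviceToHostImmediately() const { return sync_status == kNeedSyncDeviceToHostImmediately; }
      void data_sync() { std::puts("copy device -> host"); }  // stand-in for the real D2H copy
      void set_sync_status(TensorSyncStatus s) { sync_status = s; }
    };

    // Mirrors the UpdateOutputTensors hunk: fetch only when explicitly marked.
    void UpdateOutputTensor(MiniTensor *tensor) {
      if (tensor->NeedSyncDeviceToHostImmediately()) {
        tensor->data_sync();                   // device result pulled to host
        tensor->set_sync_status(kNoNeedSync);  // host copy is now authoritative
      }
    }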
@@ -158,7 +158,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
       if (tensor_address == nullptr || tensor_address != device_address) {
         need_sync = true;
       }
-    } else if (tensor->is_dirty() || tensor_address == nullptr) {
+    } else if (tensor->NeedSyncHostToDevice() || tensor_address == nullptr) {
       need_sync = true;
     } else if (tensor_address != device_address) {
       if (tensor_address->DeviceType() == device_address->DeviceType()) {
@@ -177,7 +177,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
         }
       }
     }
-    tensor->set_dirty(false);
+    tensor->set_sync_status(kNoNeedSync);
   }
 }

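Taken together, the two LoadInputData hunks give the GPU input path a three-way decision: re-upload when weight initialization forces it, when the tensor is marked host-to-device dirty or has never been uploaded, or when it is bound to a different device address. A condensed, self-contained restatement of that decision (every parameter is a stand-in, and the device-type comparison in the final branch is elided):

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Condensed, illustrative version of the need_sync decision in
    // GPUSession::LoadInputData above.
    bool NeedHostToDeviceSync(bool weight_initialized, TensorSyncStatus status,
                              const void *tensor_address, const void *device_address) {
      if (weight_initialized) {
        // Freshly initialized weight: upload unless already bound to this address.
        return tensor_address == nullptr || tensor_address != device_address;
      }
      if (status == kNeedSyncHostToDevice || tensor_address == nullptr) {
        return true;  // host data is newer, or the tensor was never uploaded
      }
      return tensor_address != device_address;  // tensor bound to a different buffer
    }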
@@ -332,7 +332,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
     for (auto &pre_output : pre_output_tensors) {
       tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
       tensor->set_device_address(pre_output->device_address());
-      tensor->set_dirty(false);
+      tensor->set_sync_status(kNoNeedSync);
       outputs->emplace_back(tensor);
     }
   } else {
@@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
     temp_shape.emplace_back(1);
     tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
     tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
-    tensor->set_dirty(false);
+    tensor->set_sync_status(kNoNeedSync);
    tensor->SetNeedWait(true);
    return tensor;
  }
@@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
       ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
-    tensor->set_need_sync(true);
+    tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
+  } else {
+    tensor->set_sync_status(kNeedSyncDeviceToHost);
   }
   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
     tensor->SetNeedWait(true);
   }
-  tensor->set_dirty(false);
   return tensor;
 }

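This second CreateCNodeOutputTensor hunk is where the new lazy state earns its keep: in graph mode on non-GPU targets the output is marked for an eager fetch (kNeedSyncDeviceToHostImmediately), while everywhere else the copy-back is deferred until someone actually reads the tensor (kNeedSyncDeviceToHost). A self-contained restatement of that policy, where the two booleans stand in for the ms_context->get_param queries of MS_CTX_EXECUTION_MODE and MS_CTX_DEVICE_TARGET:

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Restates the output-status policy from the hunk above; the parameters
    // are illustrative stand-ins for the execution-mode and device-target queries.
    TensorSyncStatus OutputSyncStatus(bool is_pynative_mode, bool is_gpu_target) {
      if (!is_pynative_mode && !is_gpu_target) {
        return kNeedSyncDeviceToHostImmediately;  // graph mode off-GPU: fetch eagerly
      }
      return kNeedSyncDeviceToHost;  // otherwise defer the device-to-host copy
    }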
@@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
   auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
   MS_EXCEPTION_IF_NULL(cur_val);
   *cur_val = 0;
-  cur_loop_tensor->set_dirty(true);
+  cur_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
   // set loop_count to zero
   MS_EXCEPTION_IF_NULL(inputs);
   inputs->push_back(cur_loop_tensor);
@@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
   auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
   MS_EXCEPTION_IF_NULL(next_val);
   *next_val = 0;
-  next_loop_tensor->set_dirty(true);
+  next_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
   // set loop_count to zero
   MS_EXCEPTION_IF_NULL(inputs);
   inputs->push_back(next_loop_tensor);
@@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
   auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
   MS_EXCEPTION_IF_NULL(epoch_val);
   *epoch_val = graph->current_epoch();
-  epoch_tensor->set_dirty(true);
+  epoch_tensor->set_sync_status(kNeedSyncHostToDevice);
   inputs->push_back(epoch_tensor);
   MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;

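All three control tensors (current loop count, next loop count, epoch) follow the same write-then-mark idiom: the value is written into host memory through data_c(), then the tensor is flagged kNeedSyncHostToDevice so the runtime uploads it before launch. A minimal sketch of that idiom (ControlTensor and WriteControlValue are illustrative, with host_value modeling the buffer behind data_c()):

    #include <cstdint>

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Illustrative stand-in for the loop/epoch control tensors in the hunks above.
    struct ControlTensor {
      int32_t host_value{0};                  // models the buffer behind data_c()
      TensorSyncStatus sync_status{kNoNeedSync};
    };

    void WriteControlValue(ControlTensor *tensor, int32_t value) {
      tensor->host_value = value;                   // write into host memory
      tensor->sync_status = kNeedSyncHostToDevice;  // runtime must upload before launch
    }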
@@ -943,7 +944,7 @@ bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor
   if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
     return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
   }
-  if (tensor->is_dirty()) {
+  if (tensor->NeedSyncHostToDevice()) {
     return true;
   }
   if (tensor->device_address() != device_address) {
@@ -992,7 +993,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
         MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
       }
     }
-    tensor->set_dirty(false);
+    tensor->set_sync_status(kNoNeedSync);
   }
 }

@@ -1140,7 +1141,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
                             tensor->data_type(), tensor->data_c())) {
       MS_LOG(ERROR) << "Failed to sync output from device to host.";
     }
-    tensor->set_dirty(false);
+    tensor->set_sync_status(kNoNeedSync);
     params_list[output_item.first] = tensor;
   }
   // call callback function here
@@ -373,7 +373,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status
     auto tensor = py::cast<tensor::TensorPtr>(input);
     auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
     new_tensor->set_device_address(tensor->device_address());
-    new_tensor->set_dirty(tensor->is_dirty());
+    new_tensor->set_sync_status(tensor->sync_status());
     result[i] = new_tensor;
   }
   *status = PYNATIVE_SUCCESS;
@@ -162,7 +162,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
   }
   if (bound_addresses_.find(address) != bound_addresses_.end()) {
     tensor->set_device_address(address);
-    tensor->set_need_sync(true);
+    tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
   } else {
     if (infer_type_id != device_type_id) {
       size_t type_size = GetTypeByte(TypeIdToType(device_type_id));
@@ -170,15 +170,16 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
       size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies<size_t>());
       address->ptr_ = resource_manager_.MemMalloc(tensor_size);
       tensor->set_device_address(address);
-      tensor->set_need_sync(true);
+      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
     } else {
       tensor->set_device_address(nullptr);
       address->ptr_ = tensor->data_c();
+      tensor->set_sync_status(kNoNeedSync);
     }
     address->ref_count_ = INIT_NODE_REF;
     (void)bound_addresses_.insert(address);
   }
-  tensor->set_dirty(false);
+
   return tensor;
 }

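The CPU runtime hunks encode an aliasing rule: if the kernel output address is already bound to another tensor, or had to be re-allocated because the inferred and device dtypes differ, the result lives in runtime-owned memory and the tensor is marked for an immediate device-to-host copy; if the types match and the address is fresh, the tensor's own host buffer is handed to the kernel (address->ptr_ = tensor->data_c()), so nothing ever needs syncing. A compressed sketch of that rule (all names illustrative):

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Compressed, illustrative version of the CPU output-binding rule above.
    TensorSyncStatus CpuOutputStatus(bool address_already_bound, bool dtype_mismatch) {
      if (address_already_bound || dtype_mismatch) {
        // Output lands in runtime-owned memory; copy it out before host reads.
        return kNeedSyncDeviceToHostImmediately;
      }
      // Kernel writes straight into the tensor's host buffer: nothing to sync.
      return kNoNeedSync;
    }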
@@ -247,7 +248,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
                        tensor->data_c())) {
       MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!";
     }
-    tensor->set_dirty(true);
+    tensor->set_sync_status(kNeedSyncHostToDevice);
   }
   address->ref_count_ = INIT_NODE_REF;
   tensor->set_device_address(address);
@@ -534,7 +534,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
     auto pk_node = input_node->cast<ParameterPtr>();
     MS_EXCEPTION_IF_NULL(tensor);
     MS_EXCEPTION_IF_NULL(pk_node);
-    if (tensor->is_dirty() || !pk_node->has_default()) {
+    if (tensor->NeedSyncHostToDevice() || !pk_node->has_default()) {
       need_sync = true;
     }
   }
@@ -551,7 +551,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
         return false;
       }
     }
-    tensor->set_dirty(false);
+    tensor->set_sync_status(kNoNeedSync);
   }
   return true;
 }
@@ -422,10 +422,9 @@ Tensor::Tensor(const Tensor &tensor)
     : MetaTensor(tensor),
       init_flag_(tensor.init_flag_),
       data_(tensor.data_),
-      dirty_(tensor.dirty_),
       id_(tensor.id_),
       event_(tensor.event_),
-      need_sync_(tensor.need_sync_),
+      sync_status_(tensor.sync_status_),
       device_sync_(tensor.device_sync_),
       padding_type_(tensor.padding_type()) {}

@@ -433,10 +432,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
     : MetaTensor(data_type, tensor.shape_),
      init_flag_(tensor.init_flag_),
      data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
-     dirty_(tensor.dirty_),
      id_(tensor.id_),
      event_(tensor.event_),
-     need_sync_(tensor.need_sync_),
+     sync_status_(tensor.sync_status_),
      device_sync_(tensor.device_sync_),
      padding_type_(tensor.padding_type()) {}

@@ -483,12 +481,11 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
 Tensor &Tensor::AssignValue(const Tensor &tensor) {
   if (this != &tensor) {
     MetaTensor::operator=(tensor);
-    dirty_ = tensor.dirty_;
     device_sync_ = tensor.device_sync_;
     data_ = tensor.data_;
     id_ = tensor.id_;
     event_ = tensor.event_;
-    need_sync_ = tensor.need_sync_;
+    sync_status_ = tensor.sync_status_;
     padding_type_ = tensor.padding_type_;
   }
   return *this;
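With the constructors and AssignValue updated, the sync status travels with a tensor exactly as dirty_ and need_sync_ used to: a copy or an assignment inherits any pending sync obligation. A tiny self-contained illustration of that invariant (MiniTensor is a stand-in, not the real class):

    #include <cassert>

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Minimal stand-in showing sync_status_ propagating through copies,
    // mirroring the copy-constructor and AssignValue changes above.
    struct MiniTensor {
      TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
    };

    int main() {
      MiniTensor a;
      a.sync_status_ = kNeedSyncDeviceToHost;
      MiniTensor b = a;  // the copy keeps the pending sync state
      assert(b.sync_status_ == kNeedSyncDeviceToHost);
      return 0;
    }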
@@ -36,7 +36,7 @@
 // Other namespace should be a sub namespace of mindspore namespace in the ME project.
 namespace mindspore {
 // brief mindspore::tensor namespace
-//
+enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };
 // A sub namespace in ME to support tensor related definition.
 namespace tensor {
 // Tensor data interface.
@@ -260,9 +260,6 @@ class Tensor : public MetaTensor {
   bool is_init() const { return init_flag_; }
   void set_init_flag(bool flag) { init_flag_ = flag; }

-  bool is_dirty() const { return dirty_; }
-  void set_dirty(const bool dirty) { dirty_ = dirty; }
-
   DeviceSyncPtr device_address() const { return device_sync_; }
   void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
   void set_padding_type(std::vector<Axis> padding_type) { padding_type_ = padding_type; }
@@ -293,17 +290,22 @@
     event_ == nullptr;
   }

-  void set_need_sync(bool need_sync) { need_sync_ = need_sync; }
+  void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; }

-  bool need_sync() const { return need_sync_; }
+  TensorSyncStatus sync_status() const { return sync_status_; }

+  bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }
+
+  bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }
+
+  bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
+
  private:
   bool init_flag_{false};
   TensorDataPtr data_{nullptr};
-  bool dirty_{true};
   std::string id_{""};
   std::shared_ptr<WaitEvent> event_{nullptr};
-  bool need_sync_{false};
+  TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
   DeviceSyncPtr device_sync_{nullptr};
   std::vector<Axis> padding_type_;
 };
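Putting the header changes together, the public surface moves from is_dirty()/set_dirty() and need_sync()/set_need_sync() to set_sync_status()/sync_status() plus three predicates, with kNeedSyncHostToDevice as the default for a freshly constructed tensor. The self-contained sketch below walks a tensor through the typical lifecycle implied by the diff; MiniTensor is an illustrative stand-in, not the real class:

    #include <cstdio>

    enum TensorSyncStatus { kNoNeedSync, kNeedSyncHostToDevice, kNeedSyncDeviceToHost, kNeedSyncDeviceToHostImmediately };

    // Illustrative stand-in for tensor::Tensor, reduced to the sync-status API.
    class MiniTensor {
     public:
      void set_sync_status(TensorSyncStatus s) { sync_status_ = s; }
      TensorSyncStatus sync_status() const { return sync_status_; }
      bool NeedSyncDeviceToHostImmediately() const { return sync_status_ == kNeedSyncDeviceToHostImmediately; }
      bool NeedSyncDeviceToHost() const { return sync_status_ == kNeedSyncDeviceToHost; }
      bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }

     private:
      TensorSyncStatus sync_status_{kNeedSyncHostToDevice};  // same default as the diff
    };

    int main() {
      MiniTensor t;
      // 1. Fresh tensor: host data must reach the device before the first launch.
      std::printf("upload needed: %d\n", t.NeedSyncHostToDevice());       // prints 1
      t.set_sync_status(kNoNeedSync);                                     // after the upload

      // 2. A kernel produced a result on the device; defer the copy-back.
      t.set_sync_status(kNeedSyncDeviceToHost);
      std::printf("lazy fetch pending: %d\n", t.NeedSyncDeviceToHost());  // prints 1

      // 3. Host is about to read: escalate, fetch, then clear.
      t.set_sync_status(kNeedSyncDeviceToHostImmediately);
      if (t.NeedSyncDeviceToHostImmediately()) {
        // (real code would call tensor->data_sync() here)
        t.set_sync_status(kNoNeedSync);
      }
      std::printf("final status clear: %d\n", t.sync_status() == kNoNeedSync);  // prints 1
      return 0;
    }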