forked from mindspore-Ecosystem/mindspore
unified runtime fix ref parameter bug and optimize running performance
This commit is contained in:
parent
e4f1da32d7
commit
3e19123533
|
@ -294,19 +294,25 @@ bool HasAbstractRef(const AnfNodePtr &node) {
|
|||
|
||||
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Only the auto moand node will modify the input.
|
||||
if (!HasAbstractMonad(cnode)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool has_monad = false;
|
||||
std::set<size_t> ref_input_indexes;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto &input = cnode->inputs().at(i);
|
||||
if (HasAbstractMonad(input)) {
|
||||
has_monad = true;
|
||||
}
|
||||
if (HasAbstractRef(input)) {
|
||||
(void)ref_input_indexes.insert(i - 1);
|
||||
}
|
||||
}
|
||||
return ref_input_indexes;
|
||||
|
||||
// Only the auto moand node will modify the input.
|
||||
if (has_monad) {
|
||||
return ref_input_indexes;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {
|
||||
|
|
|
@ -127,7 +127,7 @@ class ActorDispatcher {
|
|||
auto actor_manager = ActorMgr::GetActorMgrRef();
|
||||
MS_EXCEPTION_IF_NULL(actor_manager);
|
||||
auto base_actor = actor_manager->GetActor(aid);
|
||||
T *actor = dynamic_cast<T *>(base_actor.get());
|
||||
T *actor = static_cast<T *>(base_actor.get());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
(actor->*method)(arg);
|
||||
}
|
||||
|
@ -143,7 +143,7 @@ class ActorDispatcher {
|
|||
auto actor_manager = ActorMgr::GetActorMgrRef();
|
||||
MS_EXCEPTION_IF_NULL(actor_manager);
|
||||
auto base_actor = actor_manager->GetActor(aid);
|
||||
T *actor = dynamic_cast<T *>(base_actor.get());
|
||||
T *actor = static_cast<T *>(base_actor.get());
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
(actor->*method)(std::forward<Args1>(args)...);
|
||||
}
|
||||
|
|
|
@ -338,24 +338,18 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
|
|||
const AnfNodePtr &, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(output_data);
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
||||
const auto &data = output_data->data_;
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto formal_parameter_position = data_arrow->from_output_index_;
|
||||
// Has no the ref formal parameter.
|
||||
if (ref_formal_parameter_device_tensors_.count(formal_parameter_position) == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (data->GetMutablePtr() == nullptr) {
|
||||
std::string error_info =
|
||||
"The address of the " + std::to_string(formal_parameter_position) + "position real parameter is nullptr.";
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
if (data->ref_count() != SIZE_MAX) {
|
||||
std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
|
||||
"position real parameter is wrong:" + std::to_string(data->ref_count());
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
const auto &data = output_data->data_;
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
if ((data->GetMutablePtr() == nullptr) || (data->ref_count() != SIZE_MAX)) {
|
||||
std::string error_info = "The address of the " + std::to_string(formal_parameter_position) +
|
||||
"position real parameter is nullptr or ref count is wrong.";
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
|
@ -379,10 +373,11 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
|
|||
|
||||
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
|
||||
if ((device_tensor->format() != data->format()) || (device_tensor->DeviceType() != data->DeviceType())) {
|
||||
MS_LOG(INFO) << "The formal parameter:" << formal_parameter.first->DebugString()
|
||||
<< " input position:" << formal_parameter_position << " need copy from real parameter,"
|
||||
<< " formal parameter format:" << device_tensor->format() << " type:" << device_tensor->DeviceType()
|
||||
<< ", real parameter format:" << data->format() << " type:" << data->DeviceType();
|
||||
MS_LOG(INFO) << GetAID().Name() << " the input position:" << formal_parameter_position
|
||||
<< " copy from real parameter address:" << data << ", type:" << data->DeviceType()
|
||||
<< ", format:" << data->format() << " to formal parameter address:" << device_tensor.get()
|
||||
<< ", type:" << device_tensor->DeviceType() << ", format:" << device_tensor->format()
|
||||
<< ", formal parameter name:" << formal_parameter.first->DebugString();
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{device_tensor->device_name(), device_tensor->device_id()});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
|
|
@ -286,10 +286,9 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
|||
new_device_tensor->GetSize());
|
||||
}
|
||||
MS_LOG(INFO) << GetAID().Name() << " the input position:" << input_data->index_
|
||||
<< " copy from device type: " << input_data->data_->DeviceType()
|
||||
<< ", device format: " << input_data->data_->format()
|
||||
<< " to device type: " << new_device_tensor->DeviceType()
|
||||
<< ", device format: " << new_device_tensor->format();
|
||||
<< " copy from device address:" << input_data->data_ << ", type:" << input_data->data_->DeviceType()
|
||||
<< ", format:" << input_data->data_->format() << " to device address:" << new_device_tensor.get()
|
||||
<< ", type:" << new_device_tensor->DeviceType() << ", format:" << new_device_tensor->format();
|
||||
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
|
||||
if (!Copy(new_device_tensor.get(), input_data->data_)) {
|
||||
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
|
||||
|
@ -453,14 +452,14 @@ void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const co
|
|||
auto &input_device_tensor = input_device_tensors_[ref_input_index];
|
||||
MS_EXCEPTION_IF_NULL(input_device_tensor);
|
||||
auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(input_device_tensor);
|
||||
for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
|
||||
for (auto &new_device_tensor : need_refreshed_device_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||
MS_LOG(INFO) << GetAID().Name() << " the input position:" << ref_input_index
|
||||
<< " refresh from device type: " << input_device_tensor->DeviceType()
|
||||
<< ", device format: " << input_device_tensor->format()
|
||||
<< " to device type: " << need_refreshed_device_tensor->DeviceType()
|
||||
<< ", device format: " << need_refreshed_device_tensor->format();
|
||||
if (!Copy(need_refreshed_device_tensor, input_device_tensor)) {
|
||||
<< " refresh from device address:" << input_device_tensor
|
||||
<< ", type:" << input_device_tensor->DeviceType() << ", format:" << input_device_tensor->format()
|
||||
<< " to device address:" << new_device_tensor << ", type:" << new_device_tensor->DeviceType()
|
||||
<< ", format:" << new_device_tensor->format();
|
||||
if (!Copy(new_device_tensor, input_device_tensor)) {
|
||||
std::string error_info = "Copy input device tensor failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
|
||||
}
|
||||
|
@ -471,17 +470,17 @@ void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const co
|
|||
if (ref_output_index >= output_device_tensors_.size()) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The output index is of range.");
|
||||
}
|
||||
auto &output_device_tensor = input_device_tensors_[ref_output_index];
|
||||
auto &output_device_tensor = output_device_tensors_[ref_output_index];
|
||||
MS_EXCEPTION_IF_NULL(output_device_tensor);
|
||||
auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(output_device_tensor);
|
||||
for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
|
||||
for (auto &new_device_tensor : need_refreshed_device_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||
MS_LOG(INFO) << GetAID().Name() << " the output position:" << ref_output_index
|
||||
<< " refresh from device type: " << output_device_tensor->DeviceType()
|
||||
<< ", device format: " << output_device_tensor->format()
|
||||
<< " to device type: " << need_refreshed_device_tensor->DeviceType()
|
||||
<< ", device format: " << need_refreshed_device_tensor->format();
|
||||
if (!Copy(need_refreshed_device_tensor, output_device_tensor)) {
|
||||
<< " refresh from device address:" << output_device_tensor
|
||||
<< ", type:" << output_device_tensor->DeviceType() << ", format:" << output_device_tensor->format()
|
||||
<< " to device address:" << new_device_tensor << ", type:" << new_device_tensor->DeviceType()
|
||||
<< ", format:" << new_device_tensor->format();
|
||||
if (!Copy(new_device_tensor, output_device_tensor)) {
|
||||
std::string error_info = "Copy output device tensor failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue