Unified runtime: fix the ref parameter bug and optimize running performance

This commit is contained in:
limingqi107 2022-01-18 16:49:28 +08:00
parent e4f1da32d7
commit 3e19123533
4 changed files with 42 additions and 42 deletions

View File

@ -294,19 +294,25 @@ bool HasAbstractRef(const AnfNodePtr &node) {
// Returns the positions of the inputs of `cnode` that the node may modify.
//
// Inputs are scanned starting from index 1 (index 0 is skipped). Every input for which
// HasAbstractRef() holds is recorded at position `i - 1`, i.e. positions are relative to the
// first scanned input. The collected positions are returned only when at least one scanned
// input satisfies HasAbstractMonad(); otherwise an empty set is returned.
//
// `cnode` must not be null (enforced by MS_EXCEPTION_IF_NULL).
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);

  bool has_monad = false;
  std::set<size_t> ref_input_indexes;
  // A single pass both detects a monad input and gathers the ref input positions.
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto &input = cnode->inputs().at(i);
    if (HasAbstractMonad(input)) {
      has_monad = true;
    }
    if (HasAbstractRef(input)) {
      (void)ref_input_indexes.insert(i - 1);
    }
  }

  // Only the auto monad node will modify the input.
  if (!has_monad) {
    return {};
  }
  return ref_input_indexes;
}
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {

View File

@ -127,7 +127,7 @@ class ActorDispatcher {
auto actor_manager = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actor_manager);
auto base_actor = actor_manager->GetActor(aid);
T *actor = dynamic_cast<T *>(base_actor.get());
T *actor = static_cast<T *>(base_actor.get());
MS_EXCEPTION_IF_NULL(actor);
(actor->*method)(arg);
}
@ -143,7 +143,7 @@ class ActorDispatcher {
auto actor_manager = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actor_manager);
auto base_actor = actor_manager->GetActor(aid);
T *actor = dynamic_cast<T *>(base_actor.get());
T *actor = static_cast<T *>(base_actor.get());
MS_EXCEPTION_IF_NULL(actor);
(actor->*method)(std::forward<Args1>(args)...);
}

View File

@ -338,24 +338,18 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
const AnfNodePtr &, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(output_data);
MS_EXCEPTION_IF_NULL(data_arrow);
MS_EXCEPTION_IF_NULL(context);
const auto &data = output_data->data_;
MS_EXCEPTION_IF_NULL(data);
auto formal_parameter_position = data_arrow->from_output_index_;
// Has no the ref formal parameter.
if (ref_formal_parameter_device_tensors_.count(formal_parameter_position) == 0) {
return;
}
if (data->GetMutablePtr() == nullptr) {
std::string error_info =
"The address of the " + std::to_string(formal_parameter_position) + "position real parameter is nullptr.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
if (data->ref_count() != SIZE_MAX) {
std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
"position real parameter is wrong:" + std::to_string(data->ref_count());
MS_EXCEPTION_IF_NULL(context);
const auto &data = output_data->data_;
MS_EXCEPTION_IF_NULL(data);
if ((data->GetMutablePtr() == nullptr) || (data->ref_count() != SIZE_MAX)) {
std::string error_info = "The address of the " + std::to_string(formal_parameter_position) +
"position real parameter is nullptr or ref count is wrong.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
@ -379,10 +373,11 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
if ((device_tensor->format() != data->format()) || (device_tensor->DeviceType() != data->DeviceType())) {
MS_LOG(INFO) << "The formal parameter:" << formal_parameter.first->DebugString()
<< " input position:" << formal_parameter_position << " need copy from real parameter,"
<< " formal parameter format:" << device_tensor->format() << " type:" << device_tensor->DeviceType()
<< ", real parameter format:" << data->format() << " type:" << data->DeviceType();
MS_LOG(INFO) << GetAID().Name() << " the input position:" << formal_parameter_position
<< " copy from real parameter address:" << data << ", type:" << data->DeviceType()
<< ", format:" << data->format() << " to formal parameter address:" << device_tensor.get()
<< ", type:" << device_tensor->DeviceType() << ", format:" << device_tensor->format()
<< ", formal parameter name:" << formal_parameter.first->DebugString();
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_tensor->device_name(), device_tensor->device_id()});
MS_EXCEPTION_IF_NULL(device_context);

View File

@ -286,10 +286,9 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
new_device_tensor->GetSize());
}
MS_LOG(INFO) << GetAID().Name() << " the input position:" << input_data->index_
<< " copy from device type: " << input_data->data_->DeviceType()
<< ", device format: " << input_data->data_->format()
<< " to device type: " << new_device_tensor->DeviceType()
<< ", device format: " << new_device_tensor->format();
<< " copy from device address:" << input_data->data_ << ", type:" << input_data->data_->DeviceType()
<< ", format:" << input_data->data_->format() << " to device address:" << new_device_tensor.get()
<< ", type:" << new_device_tensor->DeviceType() << ", format:" << new_device_tensor->format();
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
if (!Copy(new_device_tensor.get(), input_data->data_)) {
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
@ -453,14 +452,14 @@ void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const co
auto &input_device_tensor = input_device_tensors_[ref_input_index];
MS_EXCEPTION_IF_NULL(input_device_tensor);
auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(input_device_tensor);
for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
for (auto &new_device_tensor : need_refreshed_device_tensors) {
MS_EXCEPTION_IF_NULL(new_device_tensor);
MS_LOG(INFO) << GetAID().Name() << " the input position:" << ref_input_index
<< " refresh from device type: " << input_device_tensor->DeviceType()
<< ", device format: " << input_device_tensor->format()
<< " to device type: " << need_refreshed_device_tensor->DeviceType()
<< ", device format: " << need_refreshed_device_tensor->format();
if (!Copy(need_refreshed_device_tensor, input_device_tensor)) {
<< " refresh from device address:" << input_device_tensor
<< ", type:" << input_device_tensor->DeviceType() << ", format:" << input_device_tensor->format()
<< " to device address:" << new_device_tensor << ", type:" << new_device_tensor->DeviceType()
<< ", format:" << new_device_tensor->format();
if (!Copy(new_device_tensor, input_device_tensor)) {
std::string error_info = "Copy input device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
}
@ -471,17 +470,17 @@ void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const co
if (ref_output_index >= output_device_tensors_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The output index is of range.");
}
auto &output_device_tensor = input_device_tensors_[ref_output_index];
auto &output_device_tensor = output_device_tensors_[ref_output_index];
MS_EXCEPTION_IF_NULL(output_device_tensor);
auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(output_device_tensor);
for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
for (auto &new_device_tensor : need_refreshed_device_tensors) {
MS_EXCEPTION_IF_NULL(new_device_tensor);
MS_LOG(INFO) << GetAID().Name() << " the output position:" << ref_output_index
<< " refresh from device type: " << output_device_tensor->DeviceType()
<< ", device format: " << output_device_tensor->format()
<< " to device type: " << need_refreshed_device_tensor->DeviceType()
<< ", device format: " << need_refreshed_device_tensor->format();
if (!Copy(need_refreshed_device_tensor, output_device_tensor)) {
<< " refresh from device address:" << output_device_tensor
<< ", type:" << output_device_tensor->DeviceType() << ", format:" << output_device_tensor->format()
<< " to device address:" << new_device_tensor << ", type:" << new_device_tensor->DeviceType()
<< ", format:" << new_device_tensor->format();
if (!Copy(new_device_tensor, output_device_tensor)) {
std::string error_info = "Copy output device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
}