!29210 unified runtime support host and device in the control flow
Merge pull request !29210 from limingqi107/new_actor_runtime
This commit is contained in:
commit
e26c4824db
|
@ -57,6 +57,7 @@ void EnvironMgr::Clear() {
|
||||||
env.second->Clear();
|
env.second->Clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
env_handles_count_ = 0;
|
||||||
envs_.clear();
|
envs_.clear();
|
||||||
mutex.unlock();
|
mutex.unlock();
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "mindrt/include/actor/op_actor.h"
|
#include "mindrt/include/actor/op_actor.h"
|
||||||
#include "runtime/framework/actor/actor_common.h"
|
#include "runtime/framework/actor/actor_common.h"
|
||||||
#include "runtime/framework/device_tensor_store.h"
|
#include "runtime/framework/device_tensor_store.h"
|
||||||
|
#include "runtime/framework/device_tensor_copy_store.h"
|
||||||
#include "runtime/hardware/device_context.h"
|
#include "runtime/hardware/device_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -283,5 +283,51 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
|
||||||
}
|
}
|
||||||
return actor_name;
|
return actor_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||||
|
if (node == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto &abs = node->abstract();
|
||||||
|
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
// Only the auto moand node will modify the input.
|
||||||
|
if (!HasAbstractMonad(cnode)) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<size_t> ref_input_indexes;
|
||||||
|
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||||
|
auto &input = cnode->inputs().at(i);
|
||||||
|
if (HasAbstractRef(input)) {
|
||||||
|
(void)ref_input_indexes.insert(i - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ref_input_indexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
std::set<size_t> ref_output_indexes;
|
||||||
|
|
||||||
|
auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||||
|
for (size_t i = 0; i < output_num; ++i) {
|
||||||
|
session::AnfWithOutIndex output_pair(cnode, i);
|
||||||
|
// Only the ref node will modify the ref input corresponding to the output.
|
||||||
|
if (!graph->IsInRefOutputMap(output_pair)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto input_pair = graph->GetRefCorrespondOutput(output_pair);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_pair.first);
|
||||||
|
if (HasAbstractRef(input_pair.first)) {
|
||||||
|
(void)ref_output_indexes.insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ref_output_indexes;
|
||||||
|
}
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -214,6 +215,14 @@ KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const Kerne
|
||||||
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
|
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
|
||||||
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name,
|
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name,
|
||||||
const AnfNodePtr &node = nullptr, const KernelGraphPtr &graph = nullptr);
|
const AnfNodePtr &node = nullptr, const KernelGraphPtr &graph = nullptr);
|
||||||
|
|
||||||
|
// Check whether the parameter is a ref parameter.
|
||||||
|
bool HasAbstractRef(const AnfNodePtr &node);
|
||||||
|
|
||||||
|
// Fetch the indexes of the inputs that are ref parameters and may be modified by the node.
|
||||||
|
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &node);
|
||||||
|
// Fetch the indexes of the outputs that are ref outputs of the node and may be modified.
|
||||||
|
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &node, const KernelGraphPtr &graph);
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -149,6 +149,19 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DumpAbstractActor(actor, ofs);
|
DumpAbstractActor(actor, ofs);
|
||||||
|
|
||||||
|
if (actor->modifiable_ref_input_indexes().size() != 0) {
|
||||||
|
ofs << "\t\tmodifiable_ref_input_indexes:" << actor->modifiable_ref_input_indexes().size() << "\n";
|
||||||
|
for (auto &ref_input_index : actor->modifiable_ref_input_indexes()) {
|
||||||
|
ofs << "\t\t\tmodifiable_ref_input_index:" << ref_input_index << "\n ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (actor->modifiable_ref_output_indexes().size() != 0) {
|
||||||
|
ofs << "\t\tmodifiable_ref_output_indexes:" << actor->modifiable_ref_output_indexes().size() << "\n";
|
||||||
|
for (auto &ref_output_index : actor->modifiable_ref_output_indexes()) {
|
||||||
|
ofs << "\t\t\tmodifiable_ref_output_index:" << ref_output_index << "\n ";
|
||||||
|
}
|
||||||
|
}
|
||||||
ofs << "\n";
|
ofs << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -215,10 +228,10 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
|
||||||
DumpAbstractActor(actor, ofs);
|
DumpAbstractActor(actor, ofs);
|
||||||
const auto &local_partials = actor->local_partials();
|
const auto &local_partials = actor->local_partials();
|
||||||
if (local_partials.size() > 0) {
|
if (local_partials.size() > 0) {
|
||||||
ofs << "\t\t\tlocal partial num:" << local_partials.size() << "\n ";
|
ofs << "\t\tlocal partial num:" << local_partials.size() << "\n ";
|
||||||
for (const auto &local_partial : local_partials) {
|
for (const auto &local_partial : local_partials) {
|
||||||
MS_EXCEPTION_IF_NULL(local_partial.second->func_graph_);
|
MS_EXCEPTION_IF_NULL(local_partial.second->func_graph_);
|
||||||
ofs << "\t\t\t\tlocal partial index:" << local_partial.first
|
ofs << "\t\t\tlocal partial index:" << local_partial.first
|
||||||
<< "\tgraph:" << local_partial.second->func_graph_->ToString()
|
<< "\tgraph:" << local_partial.second->func_graph_->ToString()
|
||||||
<< "\tparameter num:" << local_partial.second->device_tensors_.size() << "\n";
|
<< "\tparameter num:" << local_partial.second->device_tensors_.size() << "\n";
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "runtime/framework/actor/control_flow/control_actor.h"
|
#include "runtime/framework/actor/control_flow/control_actor.h"
|
||||||
|
#include "runtime/hardware/device_context_manager.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
@ -349,12 +350,12 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
|
||||||
|
|
||||||
if (data->GetMutablePtr() == nullptr) {
|
if (data->GetMutablePtr() == nullptr) {
|
||||||
std::string error_info =
|
std::string error_info =
|
||||||
"The address of the " + std::to_string(formal_parameter_position) + "position formal parameter is nullptr.";
|
"The address of the " + std::to_string(formal_parameter_position) + "position real parameter is nullptr.";
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||||
}
|
}
|
||||||
if (data->ref_count() != SIZE_MAX) {
|
if (data->ref_count() != SIZE_MAX) {
|
||||||
std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
|
std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
|
||||||
"position formal parameter is wrong:" + std::to_string(data->ref_count());
|
"position real parameter is wrong:" + std::to_string(data->ref_count());
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,19 +365,44 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
|
||||||
if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
|
if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto real_parameter = device_tensor->GetNodeIndex();
|
auto formal_parameter = device_tensor->GetNodeIndex();
|
||||||
MS_EXCEPTION_IF_NULL(real_parameter.first);
|
MS_EXCEPTION_IF_NULL(formal_parameter.first);
|
||||||
if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->format() != data->format()) ||
|
if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->type_id() != data->type_id())) {
|
||||||
(device_tensor->type_id() != data->type_id())) {
|
|
||||||
std::string error_info =
|
std::string error_info =
|
||||||
"The address of the " + std::to_string(formal_parameter_position) +
|
"The formal parameter: " + formal_parameter.first->DebugString() +
|
||||||
"position formal parameter can not be set to real parameter:" + real_parameter.first->DebugString();
|
" position:" + std::to_string(formal_parameter_position) + "can not set from real parameter," +
|
||||||
|
" formal parameter size:" + std::to_string(device_tensor->GetSize()) +
|
||||||
|
" type id:" + std::to_string(device_tensor->type_id()) +
|
||||||
|
", real parameter size:" + std::to_string(data->GetSize()) + " type id:" + std::to_string(data->type_id());
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
|
||||||
|
if ((device_tensor->format() != data->format()) || (device_tensor->DeviceType() != data->DeviceType())) {
|
||||||
|
MS_LOG(INFO) << "The formal parameter:" << formal_parameter.first->DebugString()
|
||||||
|
<< " input position:" << formal_parameter_position << " need copy from real parameter,"
|
||||||
|
<< " formal parameter format:" << device_tensor->format() << " type:" << device_tensor->DeviceType()
|
||||||
|
<< ", real parameter format:" << data->format() << " type:" << data->DeviceType();
|
||||||
|
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||||
|
{device_tensor->device_name(), device_tensor->device_id()});
|
||||||
|
MS_EXCEPTION_IF_NULL(device_context);
|
||||||
|
if ((device_tensor->GetPtr() == nullptr) &&
|
||||||
|
(!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
|
||||||
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
|
||||||
|
formal_parameter.first->DebugString(), device_tensor->GetSize());
|
||||||
|
}
|
||||||
|
if (!Copy(device_tensor.get(), data)) {
|
||||||
|
std::string error_info = "The formal parameter: " + formal_parameter.first->DebugString() +
|
||||||
|
" position:" + std::to_string(formal_parameter_position) + " copy failed.";
|
||||||
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||||
|
}
|
||||||
|
output_data->data_ = device_tensor.get();
|
||||||
|
DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
|
||||||
|
}
|
||||||
|
|
||||||
device_tensor->set_ptr(data->GetMutablePtr());
|
device_tensor->set_ptr(data->GetMutablePtr());
|
||||||
MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
|
MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
|
||||||
<< " for the ref real parameter: " << real_parameter.first->DebugString()
|
<< " for the ref formal parameter: " << formal_parameter.first->DebugString()
|
||||||
<< " in the actor: " << GetAID().Name();
|
<< " in the actor: " << GetAID().Name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -177,12 +177,20 @@ void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release the memory used by this kernel actor after launch.
// In pipeline strategy the free request for memory_free_list_ is sent to the memory manager actor; otherwise the
// memory is freed directly through FreeMemory. Afterwards every temporary input copy that holds a valid pointer is
// freed through the first device context.
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  if (strategy_ == GraphExecutionStrategy::kPipeline) {
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_contexts_[0],
                          context, GetAID());
  } else {
    FreeMemory(memory_free_list_, device_contexts_[0]);
  }

  // Free the address that is the temp store for kernel input copy.
  for (auto &copy_input_device_tensor : copy_input_device_tensors_) {
    // Entries that were never created or never allocated have nothing to free.
    if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
      device_contexts_[0]->FreeMemory(copy_input_device_tensor.get());
    }
  }
}
|
||||||
|
|
||||||
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||||
|
@ -248,39 +256,48 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
|
||||||
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
||||||
OpContext<DeviceTensor> *const context) {
|
OpContext<DeviceTensor> *const context) {
|
||||||
MS_EXCEPTION_IF_NULL(input_data);
|
MS_EXCEPTION_IF_NULL(input_data);
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
MS_EXCEPTION_IF_NULL(input_data->data_);
|
||||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, input_data->index_, false);
|
||||||
if ((input_data->data_ == nullptr) ||
|
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||||
(input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType())) {
|
if ((input_data->data_->DeviceType() == device_tensor->DeviceType()) &&
|
||||||
|
(input_data->data_->format() == device_tensor->format())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Copy from device type: " << input_data->data_->DeviceType()
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
<< " to device type: " << device_contexts_[0]->GetDeviceAddressType() << " in " << GetAID().Name();
|
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||||
|
if (IntToSize(input_data->index_) >= copy_input_device_tensors_.size()) {
|
||||||
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
|
||||||
|
}
|
||||||
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
|
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
|
||||||
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
|
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
|
||||||
nullptr, input_data->data_->GetSize(), input_data->data_->format(), input_data->data_->type_id());
|
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
|
||||||
}
|
}
|
||||||
|
auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];
|
||||||
|
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||||
// Dynamic shape need update size.
|
// Dynamic shape need update size.
|
||||||
copy_input_device_tensors_[input_data->index_]->SetSize(input_data->data_->GetSize());
|
new_device_tensor->SetSize(input_data->data_->GetSize());
|
||||||
|
// Update the input device tensor.
|
||||||
|
input_device_tensors_[input_data->index_] = new_device_tensor.get();
|
||||||
|
|
||||||
if (copy_input_device_tensors_[input_data->index_]->GetPtr() == nullptr) {
|
if ((new_device_tensor->GetPtr() == nullptr) &&
|
||||||
if (!device_contexts_[0]->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
|
(!device_contexts_[0]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize()))) {
|
||||||
copy_input_device_tensors_[input_data->index_]->GetSize())) {
|
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *(device_contexts_[0]), GetAID().Name(),
|
||||||
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *(device_contexts_[0]),
|
new_device_tensor->GetSize());
|
||||||
GetAID().Name(),
|
|
||||||
copy_input_device_tensors_[input_data->index_]->GetSize());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << GetAID().Name() << " the input position:" << input_data->index_
|
||||||
if (!Copy(copy_input_device_tensors_[input_data->index_].get(), input_data->data_)) {
|
<< " copy from device type: " << input_data->data_->DeviceType()
|
||||||
|
<< ", device format: " << input_data->data_->format()
|
||||||
|
<< " to device type: " << new_device_tensor->DeviceType()
|
||||||
|
<< ", device format: " << new_device_tensor->format();
|
||||||
|
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
|
||||||
|
if (!Copy(new_device_tensor.get(), input_data->data_)) {
|
||||||
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
|
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
|
||||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
|
||||||
|
}
|
||||||
|
if (modifiable_ref_input_indexes_.count(input_data->index_) > 0) {
|
||||||
|
DeviceTensorCopyStore::GetInstance().Insert(new_device_tensor.get(), input_data->data_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update by the copy input device tensor.
|
|
||||||
input_device_tensors_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
|
|
||||||
memory_free_list_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
|
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
|
||||||
|
@ -395,6 +412,10 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
|
||||||
|
|
||||||
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
|
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
|
||||||
|
|
||||||
|
if ((modifiable_ref_input_indexes_.size() != 0) || (modifiable_ref_output_indexes_.size() != 0)) {
|
||||||
|
RefreshDeviceTensorCopyStore(context);
|
||||||
|
}
|
||||||
|
|
||||||
// The input is invalid and needs to be erased when finish kernel launch.
|
// The input is invalid and needs to be erased when finish kernel launch.
|
||||||
EraseInput(context);
|
EraseInput(context);
|
||||||
|
|
||||||
|
@ -423,6 +444,51 @@ void KernelActor::UpdateOutputAddrSize() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy the ref data modified by the kernel back to the device tensors recorded in the device tensor copy store.
// For every modifiable ref input index and every modifiable ref output index, the device tensors fetched from
// DeviceTensorCopyStore for the corresponding input/output device tensor are overwritten with its content.
// An out-of-range index or a failed copy sets the op context to failed according to the execution strategy.
void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  for (auto &ref_input_index : modifiable_ref_input_indexes_) {
    if (ref_input_index >= input_device_tensors_.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
    }
    auto &input_device_tensor = input_device_tensors_[ref_input_index];
    MS_EXCEPTION_IF_NULL(input_device_tensor);
    auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(input_device_tensor);
    for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
      MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
      MS_LOG(INFO) << GetAID().Name() << " the input position:" << ref_input_index
                   << " refresh from device type: " << input_device_tensor->DeviceType()
                   << ", device format: " << input_device_tensor->format()
                   << " to device type: " << need_refreshed_device_tensor->DeviceType()
                   << ", device format: " << need_refreshed_device_tensor->format();
      if (!Copy(need_refreshed_device_tensor, input_device_tensor)) {
        std::string error_info = "Copy input device tensor failed: " + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
      }
    }
  }

  for (auto &ref_output_index : modifiable_ref_output_indexes_) {
    if (ref_output_index >= output_device_tensors_.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The output index is out of range.");
    }
    // The output index must address the output device tensors: this is the container that was bounds-checked above
    // and the one whose content the ref output refresh has to propagate.
    auto &output_device_tensor = output_device_tensors_[ref_output_index];
    MS_EXCEPTION_IF_NULL(output_device_tensor);
    auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(output_device_tensor);
    for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
      MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
      MS_LOG(INFO) << GetAID().Name() << " the output position:" << ref_output_index
                   << " refresh from device type: " << output_device_tensor->DeviceType()
                   << ", device format: " << output_device_tensor->format()
                   << " to device type: " << need_refreshed_device_tensor->DeviceType()
                   << ", device format: " << need_refreshed_device_tensor->format();
      if (!Copy(need_refreshed_device_tensor, output_device_tensor)) {
        std::string error_info = "Copy output device tensor failed: " + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
      }
    }
  }
}
|
||||||
|
|
||||||
void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
|
void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
|
||||||
if (recorder_aid_ != nullptr) {
|
if (recorder_aid_ != nullptr) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_);
|
MS_EXCEPTION_IF_NULL(kernel_);
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
|
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -45,13 +46,16 @@ class KernelActor : public DebugAwareActor {
|
||||||
public:
|
public:
|
||||||
KernelActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
|
KernelActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
|
||||||
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
|
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
|
||||||
GraphExecutionStrategy strategy)
|
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
|
||||||
|
const std::set<size_t> &modifiable_ref_output_indexes)
|
||||||
: DebugAwareActor(name, KernelTransformType::kKernelActor, recorder_aid, memory_manager_aid, debug_aid),
|
: DebugAwareActor(name, KernelTransformType::kKernelActor, recorder_aid, memory_manager_aid, debug_aid),
|
||||||
kernel_(kernel),
|
kernel_(kernel),
|
||||||
kernel_info_(nullptr),
|
kernel_info_(nullptr),
|
||||||
is_dynamic_shape_(false),
|
is_dynamic_shape_(false),
|
||||||
real_input_num_(0),
|
real_input_num_(0),
|
||||||
strategy_(strategy) {
|
strategy_(strategy),
|
||||||
|
modifiable_ref_input_indexes_(modifiable_ref_input_indexes),
|
||||||
|
modifiable_ref_output_indexes_(modifiable_ref_output_indexes) {
|
||||||
(void)device_contexts_.emplace_back(device_context);
|
(void)device_contexts_.emplace_back(device_context);
|
||||||
}
|
}
|
||||||
~KernelActor() override = default;
|
~KernelActor() override = default;
|
||||||
|
@ -74,6 +78,8 @@ class KernelActor : public DebugAwareActor {
|
||||||
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
|
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
|
||||||
|
|
||||||
const CNodePtr &kernel() const { return kernel_; }
|
const CNodePtr &kernel() const { return kernel_; }
|
||||||
|
const std::set<size_t> &modifiable_ref_input_indexes() const { return modifiable_ref_input_indexes_; }
|
||||||
|
const std::set<size_t> &modifiable_ref_output_indexes() const { return modifiable_ref_output_indexes_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void Run(OpContext<DeviceTensor> *const context) override;
|
void Run(OpContext<DeviceTensor> *const context) override;
|
||||||
|
@ -86,6 +92,7 @@ class KernelActor : public DebugAwareActor {
|
||||||
// Fetch the device tensor for launch.
|
// Fetch the device tensor for launch.
|
||||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
|
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||||
void FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context);
|
void FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||||
|
// Need copy when the device type or format of the real parameter is inconsistent with that of the formal parameter.
|
||||||
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
|
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
|
||||||
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
|
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
|
||||||
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);
|
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);
|
||||||
|
@ -94,6 +101,8 @@ class KernelActor : public DebugAwareActor {
|
||||||
void PreLaunchKernel(OpContext<DeviceTensor> *const context);
|
void PreLaunchKernel(OpContext<DeviceTensor> *const context);
|
||||||
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
|
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
|
||||||
void PostLaunchKernel(OpContext<DeviceTensor> *const context);
|
void PostLaunchKernel(OpContext<DeviceTensor> *const context);
|
||||||
|
// Copy the modified ref inputs/outputs back to the device tensors recorded in the device tensor copy store.
|
||||||
|
void RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context);
|
||||||
|
|
||||||
// The size of output address may be changed in dynamic shape scenario, for example, the output shape of operator
|
// The size of output address may be changed in dynamic shape scenario, for example, the output shape of operator
|
||||||
// 'Unique' will change after PostExecute, the output address size should update.
|
// 'Unique' will change after PostExecute, the output address size should update.
|
||||||
|
@ -116,8 +125,8 @@ class KernelActor : public DebugAwareActor {
|
||||||
std::vector<DeviceTensor *> input_device_tensors_;
|
std::vector<DeviceTensor *> input_device_tensors_;
|
||||||
std::vector<DeviceTensor *> output_device_tensors_;
|
std::vector<DeviceTensor *> output_device_tensors_;
|
||||||
std::vector<DeviceTensor *> workspace_device_tensors_;
|
std::vector<DeviceTensor *> workspace_device_tensors_;
|
||||||
// The received input device type may be different from the device context type in the control flow and host device
|
// The received input device type and format may be different from the formal parameter in the control flow scenarios,
|
||||||
// scenarios, so it needs to be copied from the input device type to the device context type.
|
// so it needs to be copied from the input data to real data that kernel launch needs.
|
||||||
std::vector<DeviceTensorPtr> copy_input_device_tensors_;
|
std::vector<DeviceTensorPtr> copy_input_device_tensors_;
|
||||||
|
|
||||||
// The device tensors for memory alloc and free.
|
// The device tensors for memory alloc and free.
|
||||||
|
@ -131,6 +140,10 @@ class KernelActor : public DebugAwareActor {
|
||||||
// The kernel launch info is fetched by the device tensors.
|
// The kernel launch info is fetched by the device tensors.
|
||||||
KernelLaunchInfo launch_info_;
|
KernelLaunchInfo launch_info_;
|
||||||
|
|
||||||
|
// Record the modifiable ref indexes. Used to refresh the ref data which are modified in the running.
|
||||||
|
std::set<size_t> modifiable_ref_input_indexes_;
|
||||||
|
std::set<size_t> modifiable_ref_output_indexes_;
|
||||||
|
|
||||||
// Cache output data by output index to modify the output data effectively.
|
// Cache output data by output index to modify the output data effectively.
|
||||||
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
|
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -622,14 +622,6 @@ KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
|
||||||
return {get_item_src_node, *(indexes.begin())};
|
return {get_item_src_node, *(indexes.begin())};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
|
||||||
if (node == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto &abs = node->abstract();
|
|
||||||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsCsrNode(const AnfNodePtr &node) {
|
bool IsCsrNode(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
|
|
|
@ -94,9 +94,6 @@ struct KernelGraphGroupInfo {
|
||||||
};
|
};
|
||||||
using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;
|
using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;
|
||||||
|
|
||||||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
|
||||||
// it is determined whether it is a weight.
|
|
||||||
bool HasAbstractRef(const AnfNodePtr &node);
|
|
||||||
// Check whether the node is a csr node.
|
// Check whether the node is a csr node.
|
||||||
bool IsCsrNode(const AnfNodePtr &node);
|
bool IsCsrNode(const AnfNodePtr &node);
|
||||||
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
|
||||||
|
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include "utils/hash_map.h"
|
||||||
|
#include "utils/ms_utils.h"
|
||||||
|
#include "runtime/device/device_address.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace runtime {
|
||||||
|
using DeviceTensor = mindspore::device::DeviceAddress;
|
||||||
|
|
||||||
|
// The device tensor mainly includes address ptr, size and reference count,
|
||||||
|
// which represents the basic data structure of kernel launch and transfers between actors.
|
||||||
|
// Some device tensors (such as ref real parameters) need to be refreshed while running,
|
||||||
|
// so they are kept in a store and can be obtained when an actor refreshes them by copy.
|
||||||
|
class DeviceTensorCopyStore {
|
||||||
|
public:
|
||||||
|
static DeviceTensorCopyStore &GetInstance() {
|
||||||
|
static DeviceTensorCopyStore instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Insert(DeviceTensor *const key, DeviceTensor *const value) {
|
||||||
|
MS_EXCEPTION_IF_NULL(key);
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
(void)copy_device_tensors_[key].insert(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::set<DeviceTensor *> Fetch(DeviceTensor *const key) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(key);
|
||||||
|
const auto &iter = copy_device_tensors_.find(key);
|
||||||
|
if (iter != copy_device_tensors_.end()) {
|
||||||
|
return iter->second;
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Clear() { copy_device_tensors_.clear(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DeviceTensorCopyStore() = default;
|
||||||
|
~DeviceTensorCopyStore() = default;
|
||||||
|
DISABLE_COPY_AND_ASSIGN(DeviceTensorCopyStore);
|
||||||
|
|
||||||
|
// The data storage of device tensor which need be back refreshed dynamically.
|
||||||
|
// It is created and removed dynamically in the running.
|
||||||
|
// Key is the dest device tensor, value is the source device tensors which provide copy data to dest device tensor.
|
||||||
|
mindspore::HashMap<DeviceTensor *, std::set<DeviceTensor *>> copy_device_tensors_;
|
||||||
|
};
|
||||||
|
} // namespace runtime
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
|
|
@ -224,6 +224,9 @@ void GraphScheduler::Clear() {
|
||||||
void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
|
void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
|
||||||
MS_EXCEPTION_IF_NULL(actor_set);
|
MS_EXCEPTION_IF_NULL(actor_set);
|
||||||
|
|
||||||
|
// Clear the member of DeviceTensorCopyStore.
|
||||||
|
DeviceTensorCopyStore::GetInstance().Clear();
|
||||||
|
|
||||||
for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
|
for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
|
||||||
MS_EXCEPTION_IF_NULL(super_kernel_actor);
|
MS_EXCEPTION_IF_NULL(super_kernel_actor);
|
||||||
super_kernel_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
|
super_kernel_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
|
||||||
|
@ -735,8 +738,11 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
|
||||||
for (auto &kernel : execution_order) {
|
for (auto &kernel : execution_order) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
|
if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
|
||||||
auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
|
auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
|
||||||
memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
|
auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph);
|
||||||
|
auto kernel_actor =
|
||||||
|
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
|
||||||
|
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||||
InsertActor(kernel_actor.get());
|
InsertActor(kernel_actor.get());
|
||||||
(void)kernel_actors.emplace_back(kernel_actor);
|
(void)kernel_actors.emplace_back(kernel_actor);
|
||||||
|
|
Loading…
Reference in New Issue