!29210 unified runtime support host and device in the control flow
Merge pull request !29210 from limingqi107/new_actor_runtime
This commit is contained in:
commit
e26c4824db
|
@ -57,6 +57,7 @@ void EnvironMgr::Clear() {
|
|||
env.second->Clear();
|
||||
}
|
||||
|
||||
env_handles_count_ = 0;
|
||||
envs_.clear();
|
||||
mutex.unlock();
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "runtime/framework/device_tensor_copy_store.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -283,5 +283,51 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
|
|||
}
|
||||
return actor_name;
|
||||
}
|
||||
|
||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto &abs = node->abstract();
|
||||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Only the auto moand node will modify the input.
|
||||
if (!HasAbstractMonad(cnode)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::set<size_t> ref_input_indexes;
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto &input = cnode->inputs().at(i);
|
||||
if (HasAbstractRef(input)) {
|
||||
(void)ref_input_indexes.insert(i - 1);
|
||||
}
|
||||
}
|
||||
return ref_input_indexes;
|
||||
}
|
||||
|
||||
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::set<size_t> ref_output_indexes;
|
||||
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
session::AnfWithOutIndex output_pair(cnode, i);
|
||||
// Only the ref node will modify the ref input corresponding to the output.
|
||||
if (!graph->IsInRefOutputMap(output_pair)) {
|
||||
continue;
|
||||
}
|
||||
auto input_pair = graph->GetRefCorrespondOutput(output_pair);
|
||||
MS_EXCEPTION_IF_NULL(input_pair.first);
|
||||
if (HasAbstractRef(input_pair.first)) {
|
||||
(void)ref_output_indexes.insert(i);
|
||||
}
|
||||
}
|
||||
return ref_output_indexes;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <thread>
|
||||
#include <algorithm>
|
||||
|
@ -214,6 +215,14 @@ KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const Kerne
|
|||
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
|
||||
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name,
|
||||
const AnfNodePtr &node = nullptr, const KernelGraphPtr &graph = nullptr);
|
||||
|
||||
// Check whether the parameter is a ref parameter.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
|
||||
// Fetch the indexes of the node inputs that are ref parameters and therefore may be modified by the node.
|
||||
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &node);
|
||||
// Fetch the indexes of the node outputs that are ref outputs in the graph and therefore may be modified by the node.
|
||||
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &node, const KernelGraphPtr &graph);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -149,6 +149,19 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) {
|
|||
}
|
||||
|
||||
DumpAbstractActor(actor, ofs);
|
||||
|
||||
if (actor->modifiable_ref_input_indexes().size() != 0) {
|
||||
ofs << "\t\tmodifiable_ref_input_indexes:" << actor->modifiable_ref_input_indexes().size() << "\n";
|
||||
for (auto &ref_input_index : actor->modifiable_ref_input_indexes()) {
|
||||
ofs << "\t\t\tmodifiable_ref_input_index:" << ref_input_index << "\n ";
|
||||
}
|
||||
}
|
||||
if (actor->modifiable_ref_output_indexes().size() != 0) {
|
||||
ofs << "\t\tmodifiable_ref_output_indexes:" << actor->modifiable_ref_output_indexes().size() << "\n";
|
||||
for (auto &ref_output_index : actor->modifiable_ref_output_indexes()) {
|
||||
ofs << "\t\t\tmodifiable_ref_output_index:" << ref_output_index << "\n ";
|
||||
}
|
||||
}
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
|
@ -215,10 +228,10 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
|
|||
DumpAbstractActor(actor, ofs);
|
||||
const auto &local_partials = actor->local_partials();
|
||||
if (local_partials.size() > 0) {
|
||||
ofs << "\t\t\tlocal partial num:" << local_partials.size() << "\n ";
|
||||
ofs << "\t\tlocal partial num:" << local_partials.size() << "\n ";
|
||||
for (const auto &local_partial : local_partials) {
|
||||
MS_EXCEPTION_IF_NULL(local_partial.second->func_graph_);
|
||||
ofs << "\t\t\t\tlocal partial index:" << local_partial.first
|
||||
ofs << "\t\t\tlocal partial index:" << local_partial.first
|
||||
<< "\tgraph:" << local_partial.second->func_graph_->ToString()
|
||||
<< "\tparameter num:" << local_partial.second->device_tensors_.size() << "\n";
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "runtime/framework/actor/control_flow/control_actor.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
@ -349,12 +350,12 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
|
|||
|
||||
if (data->GetMutablePtr() == nullptr) {
|
||||
std::string error_info =
|
||||
"The address of the " + std::to_string(formal_parameter_position) + "position formal parameter is nullptr.";
|
||||
"The address of the " + std::to_string(formal_parameter_position) + "position real parameter is nullptr.";
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
if (data->ref_count() != SIZE_MAX) {
|
||||
std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
|
||||
"position formal parameter is wrong:" + std::to_string(data->ref_count());
|
||||
"position real parameter is wrong:" + std::to_string(data->ref_count());
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
|
@ -364,19 +365,44 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
|
|||
if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
|
||||
continue;
|
||||
}
|
||||
auto real_parameter = device_tensor->GetNodeIndex();
|
||||
MS_EXCEPTION_IF_NULL(real_parameter.first);
|
||||
if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->format() != data->format()) ||
|
||||
(device_tensor->type_id() != data->type_id())) {
|
||||
auto formal_parameter = device_tensor->GetNodeIndex();
|
||||
MS_EXCEPTION_IF_NULL(formal_parameter.first);
|
||||
if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->type_id() != data->type_id())) {
|
||||
std::string error_info =
|
||||
"The address of the " + std::to_string(formal_parameter_position) +
|
||||
"position formal parameter can not be set to real parameter:" + real_parameter.first->DebugString();
|
||||
"The formal parameter: " + formal_parameter.first->DebugString() +
|
||||
" position:" + std::to_string(formal_parameter_position) + "can not set from real parameter," +
|
||||
" formal parameter size:" + std::to_string(device_tensor->GetSize()) +
|
||||
" type id:" + std::to_string(device_tensor->type_id()) +
|
||||
", real parameter size:" + std::to_string(data->GetSize()) + " type id:" + std::to_string(data->type_id());
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
|
||||
if ((device_tensor->format() != data->format()) || (device_tensor->DeviceType() != data->DeviceType())) {
|
||||
MS_LOG(INFO) << "The formal parameter:" << formal_parameter.first->DebugString()
|
||||
<< " input position:" << formal_parameter_position << " need copy from real parameter,"
|
||||
<< " formal parameter format:" << device_tensor->format() << " type:" << device_tensor->DeviceType()
|
||||
<< ", real parameter format:" << data->format() << " type:" << data->DeviceType();
|
||||
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{device_tensor->device_name(), device_tensor->device_id()});
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
if ((device_tensor->GetPtr() == nullptr) &&
|
||||
(!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
|
||||
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
|
||||
formal_parameter.first->DebugString(), device_tensor->GetSize());
|
||||
}
|
||||
if (!Copy(device_tensor.get(), data)) {
|
||||
std::string error_info = "The formal parameter: " + formal_parameter.first->DebugString() +
|
||||
" position:" + std::to_string(formal_parameter_position) + " copy failed.";
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
output_data->data_ = device_tensor.get();
|
||||
DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
|
||||
}
|
||||
|
||||
device_tensor->set_ptr(data->GetMutablePtr());
|
||||
MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
|
||||
<< " for the ref real parameter: " << real_parameter.first->DebugString()
|
||||
<< " for the ref formal parameter: " << formal_parameter.first->DebugString()
|
||||
<< " in the actor: " << GetAID().Name();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -177,12 +177,20 @@ void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
|||
}
|
||||
|
||||
// Release the memory used by this actor.
// With the pipeline strategy a MemoryManagerActor::FreeMemory message carrying memory_free_list_ is sent to the
// memory manager actor; with any other strategy FreeMemory is called directly. Afterwards every non-null entry of
// copy_input_device_tensors_ that still holds a pointer is freed through the first device context.
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(device_contexts_[0]);
  if (strategy_ == GraphExecutionStrategy::kPipeline) {
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_contexts_[0],
                          context, GetAID());
  } else {
    FreeMemory(memory_free_list_, device_contexts_[0]);
  }

  // Free the address that is the temp store for kernel input copy.
  for (auto &copy_input_device_tensor : copy_input_device_tensors_) {
    if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
      device_contexts_[0]->FreeMemory(copy_input_device_tensor.get());
    }
  }
}
|
||||
|
||||
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
|
@ -248,39 +256,48 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
|
|||
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
||||
OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(input_data);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
if ((input_data->data_ == nullptr) ||
|
||||
(input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType())) {
|
||||
MS_EXCEPTION_IF_NULL(input_data->data_);
|
||||
const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, input_data->index_, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
if ((input_data->data_->DeviceType() == device_tensor->DeviceType()) &&
|
||||
(input_data->data_->format() == device_tensor->format())) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Copy from device type: " << input_data->data_->DeviceType()
|
||||
<< " to device type: " << device_contexts_[0]->GetDeviceAddressType() << " in " << GetAID().Name();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
if (IntToSize(input_data->index_) >= copy_input_device_tensors_.size()) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
|
||||
}
|
||||
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
|
||||
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
|
||||
nullptr, input_data->data_->GetSize(), input_data->data_->format(), input_data->data_->type_id());
|
||||
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
|
||||
}
|
||||
auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];
|
||||
MS_EXCEPTION_IF_NULL(new_device_tensor);
|
||||
// Dynamic shape need update size.
|
||||
copy_input_device_tensors_[input_data->index_]->SetSize(input_data->data_->GetSize());
|
||||
new_device_tensor->SetSize(input_data->data_->GetSize());
|
||||
// Update the input device tensor.
|
||||
input_device_tensors_[input_data->index_] = new_device_tensor.get();
|
||||
|
||||
if (copy_input_device_tensors_[input_data->index_]->GetPtr() == nullptr) {
|
||||
if (!device_contexts_[0]->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
|
||||
copy_input_device_tensors_[input_data->index_]->GetSize())) {
|
||||
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *(device_contexts_[0]),
|
||||
GetAID().Name(),
|
||||
copy_input_device_tensors_[input_data->index_]->GetSize());
|
||||
}
|
||||
if ((new_device_tensor->GetPtr() == nullptr) &&
|
||||
(!device_contexts_[0]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize()))) {
|
||||
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *(device_contexts_[0]), GetAID().Name(),
|
||||
new_device_tensor->GetSize());
|
||||
}
|
||||
|
||||
if (!Copy(copy_input_device_tensors_[input_data->index_].get(), input_data->data_)) {
|
||||
MS_LOG(INFO) << GetAID().Name() << " the input position:" << input_data->index_
|
||||
<< " copy from device type: " << input_data->data_->DeviceType()
|
||||
<< ", device format: " << input_data->data_->format()
|
||||
<< " to device type: " << new_device_tensor->DeviceType()
|
||||
<< ", device format: " << new_device_tensor->format();
|
||||
// Copy from the real parameter to formal parameter and insert the device tensor copy store.
|
||||
if (!Copy(new_device_tensor.get(), input_data->data_)) {
|
||||
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
|
||||
}
|
||||
if (modifiable_ref_input_indexes_.count(input_data->index_) > 0) {
|
||||
DeviceTensorCopyStore::GetInstance().Insert(new_device_tensor.get(), input_data->data_);
|
||||
}
|
||||
|
||||
// Update by the copy input device tensor.
|
||||
input_device_tensors_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
|
||||
memory_free_list_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
|
||||
}
|
||||
|
||||
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
|
||||
|
@ -395,6 +412,10 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
|
|||
|
||||
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
|
||||
|
||||
if ((modifiable_ref_input_indexes_.size() != 0) || (modifiable_ref_output_indexes_.size() != 0)) {
|
||||
RefreshDeviceTensorCopyStore(context);
|
||||
}
|
||||
|
||||
// The input is invalid and needs to be erased when finish kernel launch.
|
||||
EraseInput(context);
|
||||
|
||||
|
@ -423,6 +444,51 @@ void KernelActor::UpdateOutputAddrSize() {
|
|||
}
|
||||
}
|
||||
|
||||
// Copy the ref data of this kernel back to the device tensors recorded in DeviceTensorCopyStore.
// For each modifiable ref input index and each modifiable ref output index, the kernel's own device tensor is used as
// the key to fetch the recorded device tensors, and its content is copied into every fetched one.
// An out-of-range index or a failed copy is reported through SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY
// (NOTE(review): the macro presumably returns from this function — confirm against its definition).
void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  for (const auto &ref_input_index : modifiable_ref_input_indexes_) {
    if (ref_input_index >= input_device_tensors_.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
    }
    auto &input_device_tensor = input_device_tensors_[ref_input_index];
    MS_EXCEPTION_IF_NULL(input_device_tensor);
    auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(input_device_tensor);
    for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
      MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
      MS_LOG(INFO) << GetAID().Name() << " the input position:" << ref_input_index
                   << " refresh from device type: " << input_device_tensor->DeviceType()
                   << ", device format: " << input_device_tensor->format()
                   << " to device type: " << need_refreshed_device_tensor->DeviceType()
                   << ", device format: " << need_refreshed_device_tensor->format();
      if (!Copy(need_refreshed_device_tensor, input_device_tensor)) {
        std::string error_info = "Copy input device tensor failed: " + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
      }
    }
  }

  for (const auto &ref_output_index : modifiable_ref_output_indexes_) {
    if (ref_output_index >= output_device_tensors_.size()) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The output index is out of range.");
    }
    // The output index must address the output device tensors: this is the container that was bounds-checked above,
    // and it is the output (not an input at the same position) that has to be refreshed.
    auto &output_device_tensor = output_device_tensors_[ref_output_index];
    MS_EXCEPTION_IF_NULL(output_device_tensor);
    auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(output_device_tensor);
    for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
      MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
      MS_LOG(INFO) << GetAID().Name() << " the output position:" << ref_output_index
                   << " refresh from device type: " << output_device_tensor->DeviceType()
                   << ", device format: " << output_device_tensor->format()
                   << " to device type: " << need_refreshed_device_tensor->DeviceType()
                   << ", device format: " << need_refreshed_device_tensor->format();
      if (!Copy(need_refreshed_device_tensor, output_device_tensor)) {
        std::string error_info = "Copy output device tensor failed: " + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
      }
    }
  }
}
|
||||
|
||||
void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
|
||||
if (recorder_aid_ != nullptr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_);
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
@ -45,13 +46,16 @@ class KernelActor : public DebugAwareActor {
|
|||
public:
|
||||
KernelActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
|
||||
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
|
||||
GraphExecutionStrategy strategy)
|
||||
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
|
||||
const std::set<size_t> &modifiable_ref_output_indexes)
|
||||
: DebugAwareActor(name, KernelTransformType::kKernelActor, recorder_aid, memory_manager_aid, debug_aid),
|
||||
kernel_(kernel),
|
||||
kernel_info_(nullptr),
|
||||
is_dynamic_shape_(false),
|
||||
real_input_num_(0),
|
||||
strategy_(strategy) {
|
||||
strategy_(strategy),
|
||||
modifiable_ref_input_indexes_(modifiable_ref_input_indexes),
|
||||
modifiable_ref_output_indexes_(modifiable_ref_output_indexes) {
|
||||
(void)device_contexts_.emplace_back(device_context);
|
||||
}
|
||||
~KernelActor() override = default;
|
||||
|
@ -74,6 +78,8 @@ class KernelActor : public DebugAwareActor {
|
|||
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
const CNodePtr &kernel() const { return kernel_; }
|
||||
const std::set<size_t> &modifiable_ref_input_indexes() const { return modifiable_ref_input_indexes_; }
|
||||
const std::set<size_t> &modifiable_ref_output_indexes() const { return modifiable_ref_output_indexes_; }
|
||||
|
||||
protected:
|
||||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
|
@ -86,6 +92,7 @@ class KernelActor : public DebugAwareActor {
|
|||
// Fetch the device tensor for launch.
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
void FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
// A copy is needed when the device type or format of the real parameter differs from that of the formal parameter.
|
||||
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
|
||||
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
|
||||
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);
|
||||
|
@ -94,6 +101,8 @@ class KernelActor : public DebugAwareActor {
|
|||
void PreLaunchKernel(OpContext<DeviceTensor> *const context);
|
||||
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
|
||||
void PostLaunchKernel(OpContext<DeviceTensor> *const context);
|
||||
// Copy the modified ref inputs and outputs back to the device tensors recorded in the DeviceTensorCopyStore.
|
||||
void RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The size of output address may be changed in dynamic shape scenario, for example, the output shape of operator
|
||||
// 'Unique' will change after PostExecute, the output address size should update.
|
||||
|
@ -116,8 +125,8 @@ class KernelActor : public DebugAwareActor {
|
|||
std::vector<DeviceTensor *> input_device_tensors_;
|
||||
std::vector<DeviceTensor *> output_device_tensors_;
|
||||
std::vector<DeviceTensor *> workspace_device_tensors_;
|
||||
// The received input device type may be different from the device context type in the control flow and host device
|
||||
// scenarios, so it needs to be copied from the input device type to the device context type.
|
||||
// The received input device type and format may be different from the formal parameter in the control flow scenarios,
|
||||
// so it needs to be copied from the input data to real data that kernel launch needs.
|
||||
std::vector<DeviceTensorPtr> copy_input_device_tensors_;
|
||||
|
||||
// The device tensors for memory alloc and free.
|
||||
|
@ -131,6 +140,10 @@ class KernelActor : public DebugAwareActor {
|
|||
// The kernel launch info is fetched by the device tensors.
|
||||
KernelLaunchInfo launch_info_;
|
||||
|
||||
// Record the modifiable ref indexes. Used to refresh the ref data which are modified in the running.
|
||||
std::set<size_t> modifiable_ref_input_indexes_;
|
||||
std::set<size_t> modifiable_ref_output_indexes_;
|
||||
|
||||
// Cache output data by output index to modify the output data effectively.
|
||||
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
|
||||
};
|
||||
|
|
|
@ -622,14 +622,6 @@ KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
|
|||
return {get_item_src_node, *(indexes.begin())};
|
||||
}
|
||||
|
||||
bool HasAbstractRef(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto &abs = node->abstract();
|
||||
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
|
||||
}
|
||||
|
||||
bool IsCsrNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
|
|
@ -94,9 +94,6 @@ struct KernelGraphGroupInfo {
|
|||
};
|
||||
using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;
|
||||
|
||||
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
|
||||
// it is determined whether it is a weight.
|
||||
bool HasAbstractRef(const AnfNodePtr &node);
|
||||
// Check whether the node is a csr node.
|
||||
bool IsCsrNode(const AnfNodePtr &node);
|
||||
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "runtime/device/device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using DeviceTensor = mindspore::device::DeviceAddress;
|
||||
|
||||
// The device tensor mainly includes address ptr, size and reference count,
|
||||
// which represents the basic data structure of kernel launch and transfers between actors.
|
||||
// Some device tensors (such as ref real parameters) need to be refreshed while running,
|
||||
// so the copies made for them are recorded in this store and can be fetched by an actor when it refreshes them.
|
||||
class DeviceTensorCopyStore {
|
||||
public:
|
||||
static DeviceTensorCopyStore &GetInstance() {
|
||||
static DeviceTensorCopyStore instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Insert(DeviceTensor *const key, DeviceTensor *const value) {
|
||||
MS_EXCEPTION_IF_NULL(key);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
(void)copy_device_tensors_[key].insert(value);
|
||||
}
|
||||
|
||||
std::set<DeviceTensor *> Fetch(DeviceTensor *const key) const {
|
||||
MS_EXCEPTION_IF_NULL(key);
|
||||
const auto &iter = copy_device_tensors_.find(key);
|
||||
if (iter != copy_device_tensors_.end()) {
|
||||
return iter->second;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
void Clear() { copy_device_tensors_.clear(); }
|
||||
|
||||
private:
|
||||
DeviceTensorCopyStore() = default;
|
||||
~DeviceTensorCopyStore() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(DeviceTensorCopyStore);
|
||||
|
||||
// The data storage of device tensor which need be back refreshed dynamically.
|
||||
// It is created and removed dynamically in the running.
|
||||
// Key is the dest device tensor, value is the source device tensors which provide copy data to dest device tensor.
|
||||
mindspore::HashMap<DeviceTensor *, std::set<DeviceTensor *>> copy_device_tensors_;
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
|
|
@ -224,6 +224,9 @@ void GraphScheduler::Clear() {
|
|||
void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
|
||||
// Clear the member of DeviceTensorCopyStore.
|
||||
DeviceTensorCopyStore::GetInstance().Clear();
|
||||
|
||||
for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(super_kernel_actor);
|
||||
super_kernel_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
|
||||
|
@ -735,8 +738,11 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
|
|||
for (auto &kernel : execution_order) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
|
||||
auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
|
||||
memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
|
||||
auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
|
||||
auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph);
|
||||
auto kernel_actor =
|
||||
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
|
||||
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
InsertActor(kernel_actor.get());
|
||||
(void)kernel_actors.emplace_back(kernel_actor);
|
||||
|
|
Loading…
Reference in New Issue