!29210 unified runtime supports host and device in the control flow

Merge pull request !29210 from limingqi107/new_actor_runtime
i-robot 2022-01-18 01:04:20 +00:00 committed by Gitee
commit e26c4824db
12 changed files with 292 additions and 50 deletions


@@ -57,6 +57,7 @@ void EnvironMgr::Clear() {
env.second->Clear();
}
env_handles_count_ = 0;
envs_.clear();
mutex.unlock();
}


@@ -25,6 +25,7 @@
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/device_tensor_copy_store.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {


@@ -283,5 +283,51 @@ std::string FetchActorName(KernelTransformType kernel_type, const std::string &a
}
return actor_name;
}
bool HasAbstractRef(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto &abs = node->abstract();
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
}
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
// Only the auto monad node will modify the input.
if (!HasAbstractMonad(cnode)) {
return {};
}
std::set<size_t> ref_input_indexes;
for (size_t i = 1; i < cnode->size(); ++i) {
auto &input = cnode->inputs().at(i);
if (HasAbstractRef(input)) {
(void)ref_input_indexes.insert(i - 1);
}
}
return ref_input_indexes;
}
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
std::set<size_t> ref_output_indexes;
auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t i = 0; i < output_num; ++i) {
session::AnfWithOutIndex output_pair(cnode, i);
// Only the ref node will modify the ref input corresponding to the output.
if (!graph->IsInRefOutputMap(output_pair)) {
continue;
}
auto input_pair = graph->GetRefCorrespondOutput(output_pair);
MS_EXCEPTION_IF_NULL(input_pair.first);
if (HasAbstractRef(input_pair.first)) {
(void)ref_output_indexes.insert(i);
}
}
return ref_output_indexes;
}
} // namespace runtime
} // namespace mindspore


@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include <set>
#include <utility>
#include <thread>
#include <algorithm>
@@ -214,6 +215,14 @@ KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const Kerne
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name,
const AnfNodePtr &node = nullptr, const KernelGraphPtr &graph = nullptr);
// Check whether the parameter is a ref parameter.
bool HasAbstractRef(const AnfNodePtr &node);
// Fetch the input indexes which may be modified, i.e. the indexes of the ref input parameters.
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &node);
// Fetch the output indexes which may be modified, i.e. the ref node outputs whose corresponding inputs are ref parameters.
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &node, const KernelGraphPtr &graph);
} // namespace runtime
} // namespace mindspore


@@ -149,6 +149,19 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) {
}
DumpAbstractActor(actor, ofs);
if (actor->modifiable_ref_input_indexes().size() != 0) {
ofs << "\t\tmodifiable_ref_input_indexes:" << actor->modifiable_ref_input_indexes().size() << "\n";
for (auto &ref_input_index : actor->modifiable_ref_input_indexes()) {
ofs << "\t\t\tmodifiable_ref_input_index:" << ref_input_index << "\n ";
}
}
if (actor->modifiable_ref_output_indexes().size() != 0) {
ofs << "\t\tmodifiable_ref_output_indexes:" << actor->modifiable_ref_output_indexes().size() << "\n";
for (auto &ref_output_index : actor->modifiable_ref_output_indexes()) {
ofs << "\t\t\tmodifiable_ref_output_index:" << ref_output_index << "\n ";
}
}
ofs << "\n";
}
@@ -215,10 +228,10 @@ void DumpControlActor(const ControlActor *actor, std::ofstream &ofs) {
DumpAbstractActor(actor, ofs);
const auto &local_partials = actor->local_partials();
if (local_partials.size() > 0) {
ofs << "\t\t\tlocal partial num:" << local_partials.size() << "\n ";
ofs << "\t\tlocal partial num:" << local_partials.size() << "\n ";
for (const auto &local_partial : local_partials) {
MS_EXCEPTION_IF_NULL(local_partial.second->func_graph_);
ofs << "\t\t\t\tlocal partial index:" << local_partial.first
ofs << "\t\t\tlocal partial index:" << local_partial.first
<< "\tgraph:" << local_partial.second->func_graph_->ToString()
<< "\tparameter num:" << local_partial.second->device_tensors_.size() << "\n";
}


@@ -15,6 +15,7 @@
*/
#include "runtime/framework/actor/control_flow/control_actor.h"
#include "runtime/hardware/device_context_manager.h"
namespace mindspore {
namespace runtime {
@@ -349,12 +350,12 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
if (data->GetMutablePtr() == nullptr) {
std::string error_info =
"The address of the " + std::to_string(formal_parameter_position) + "position formal parameter is nullptr.";
"The address of the " + std::to_string(formal_parameter_position) + "position real parameter is nullptr.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
if (data->ref_count() != SIZE_MAX) {
std::string error_info = "The ref count of the " + std::to_string(formal_parameter_position) +
"position formal parameter is wrong:" + std::to_string(data->ref_count());
"position real parameter is wrong:" + std::to_string(data->ref_count());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
@@ -364,19 +365,44 @@ void ControlActor::UpdateOutputData(OpData<DeviceTensor> *const output_data, con
if ((device_tensor.get() == data) || (device_tensor->GetMutablePtr() == data->GetMutablePtr())) {
continue;
}
auto real_parameter = device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(real_parameter.first);
if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->format() != data->format()) ||
(device_tensor->type_id() != data->type_id())) {
auto formal_parameter = device_tensor->GetNodeIndex();
MS_EXCEPTION_IF_NULL(formal_parameter.first);
if ((device_tensor->GetSize() != data->GetSize()) || (device_tensor->type_id() != data->type_id())) {
std::string error_info =
"The address of the " + std::to_string(formal_parameter_position) +
"position formal parameter can not be set to real parameter:" + real_parameter.first->DebugString();
"The formal parameter: " + formal_parameter.first->DebugString() +
" position:" + std::to_string(formal_parameter_position) + "can not set from real parameter," +
" formal parameter size:" + std::to_string(device_tensor->GetSize()) +
" type id:" + std::to_string(device_tensor->type_id()) +
", real parameter size:" + std::to_string(data->GetSize()) + " type id:" + std::to_string(data->type_id());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// Copy from the real parameter to the formal parameter and insert it into the device tensor copy store.
if ((device_tensor->format() != data->format()) || (device_tensor->DeviceType() != data->DeviceType())) {
MS_LOG(INFO) << "The formal parameter:" << formal_parameter.first->DebugString()
<< " input position:" << formal_parameter_position << " need copy from real parameter,"
<< " formal parameter format:" << device_tensor->format() << " type:" << device_tensor->DeviceType()
<< ", real parameter format:" << data->format() << " type:" << data->DeviceType();
const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device_tensor->device_name(), device_tensor->device_id()});
MS_EXCEPTION_IF_NULL(device_context);
if ((device_tensor->GetPtr() == nullptr) &&
(!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *device_context,
formal_parameter.first->DebugString(), device_tensor->GetSize());
}
if (!Copy(device_tensor.get(), data)) {
std::string error_info = "The formal parameter: " + formal_parameter.first->DebugString() +
" position:" + std::to_string(formal_parameter_position) + " copy failed.";
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
output_data->data_ = device_tensor.get();
DeviceTensorCopyStore::GetInstance().Insert(device_tensor.get(), data);
continue;
}
device_tensor->set_ptr(data->GetMutablePtr());
MS_LOG(DEBUG) << "Set the ptr: " << data->GetMutablePtr()
<< " for the ref real parameter: " << real_parameter.first->DebugString()
<< " for the ref formal parameter: " << formal_parameter.first->DebugString()
<< " in the actor: " << GetAID().Name();
}
}


@@ -177,12 +177,20 @@ void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
}
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
if (strategy_ == GraphExecutionStrategy::kPipeline) {
ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_contexts_[0],
context, GetAID());
} else {
FreeMemory(memory_free_list_, device_contexts_[0]);
}
// Free the addresses that are the temp storage for the kernel input copies.
for (auto &copy_input_device_tensor : copy_input_device_tensors_) {
if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
device_contexts_[0]->FreeMemory(copy_input_device_tensor.get());
}
}
}
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
@@ -248,39 +256,48 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data);
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
if ((input_data->data_ == nullptr) ||
(input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType())) {
MS_EXCEPTION_IF_NULL(input_data->data_);
const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, input_data->index_, false);
MS_EXCEPTION_IF_NULL(device_tensor);
if ((input_data->data_->DeviceType() == device_tensor->DeviceType()) &&
(input_data->data_->format() == device_tensor->format())) {
return;
}
MS_LOG(DEBUG) << "Copy from device type: " << input_data->data_->DeviceType()
<< " to device type: " << device_contexts_[0]->GetDeviceAddressType() << " in " << GetAID().Name();
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
if (IntToSize(input_data->index_) >= copy_input_device_tensors_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
}
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
nullptr, input_data->data_->GetSize(), input_data->data_->format(), input_data->data_->type_id());
nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
}
auto &new_device_tensor = copy_input_device_tensors_[input_data->index_];
MS_EXCEPTION_IF_NULL(new_device_tensor);
// Dynamic shape needs to update the size.
copy_input_device_tensors_[input_data->index_]->SetSize(input_data->data_->GetSize());
new_device_tensor->SetSize(input_data->data_->GetSize());
// Update the input device tensor.
input_device_tensors_[input_data->index_] = new_device_tensor.get();
if (copy_input_device_tensors_[input_data->index_]->GetPtr() == nullptr) {
if (!device_contexts_[0]->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
copy_input_device_tensors_[input_data->index_]->GetSize())) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context, *(device_contexts_[0]),
GetAID().Name(),
copy_input_device_tensors_[input_data->index_]->GetSize());
}
if ((new_device_tensor->GetPtr() == nullptr) &&
(!device_contexts_[0]->AllocateMemory(new_device_tensor.get(), new_device_tensor->GetSize()))) {
SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *(device_contexts_[0]), GetAID().Name(),
new_device_tensor->GetSize());
}
if (!Copy(copy_input_device_tensors_[input_data->index_].get(), input_data->data_)) {
MS_LOG(INFO) << GetAID().Name() << " the input position:" << input_data->index_
<< " copy from device type: " << input_data->data_->DeviceType()
<< ", device format: " << input_data->data_->format()
<< " to device type: " << new_device_tensor->DeviceType()
<< ", device format: " << new_device_tensor->format();
// Copy from the real parameter to the formal parameter and insert it into the device tensor copy store.
if (!Copy(new_device_tensor.get(), input_data->data_)) {
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
}
if (modifiable_ref_input_indexes_.count(input_data->index_) > 0) {
DeviceTensorCopyStore::GetInstance().Insert(new_device_tensor.get(), input_data->data_);
}
// Update by the copy input device tensor.
input_device_tensors_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
memory_free_list_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
}
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
@@ -395,6 +412,10 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
if ((modifiable_ref_input_indexes_.size() != 0) || (modifiable_ref_output_indexes_.size() != 0)) {
RefreshDeviceTensorCopyStore(context);
}
// The input is invalid and needs to be erased when the kernel launch finishes.
EraseInput(context);
@@ -423,6 +444,51 @@ void KernelActor::UpdateOutputAddrSize() {
}
}
void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
for (auto &ref_input_index : modifiable_ref_input_indexes_) {
if (ref_input_index >= input_device_tensors_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is of range.");
}
auto &input_device_tensor = input_device_tensors_[ref_input_index];
MS_EXCEPTION_IF_NULL(input_device_tensor);
auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(input_device_tensor);
for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
MS_LOG(INFO) << GetAID().Name() << " the input position:" << ref_input_index
<< " refresh from device type: " << input_device_tensor->DeviceType()
<< ", device format: " << input_device_tensor->format()
<< " to device type: " << need_refreshed_device_tensor->DeviceType()
<< ", device format: " << need_refreshed_device_tensor->format();
if (!Copy(need_refreshed_device_tensor, input_device_tensor)) {
std::string error_info = "Copy input device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
}
}
}
for (auto &ref_output_index : modifiable_ref_output_indexes_) {
if (ref_output_index >= output_device_tensors_.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The output index is of range.");
}
auto &output_device_tensor = output_device_tensors_[ref_output_index];
MS_EXCEPTION_IF_NULL(output_device_tensor);
auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(output_device_tensor);
for (auto &need_refreshed_device_tensor : need_refreshed_device_tensors) {
MS_EXCEPTION_IF_NULL(need_refreshed_device_tensor);
MS_LOG(INFO) << GetAID().Name() << " the output position:" << ref_output_index
<< " refresh from device type: " << output_device_tensor->DeviceType()
<< ", device format: " << output_device_tensor->format()
<< " to device type: " << need_refreshed_device_tensor->DeviceType()
<< ", device format: " << need_refreshed_device_tensor->format();
if (!Copy(need_refreshed_device_tensor, output_device_tensor)) {
std::string error_info = "Copy output device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
}
}
}
}
void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
if (recorder_aid_ != nullptr) {
MS_EXCEPTION_IF_NULL(kernel_);


@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <utility>
@@ -45,13 +46,16 @@ class KernelActor : public DebugAwareActor {
public:
KernelActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy)
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
: DebugAwareActor(name, KernelTransformType::kKernelActor, recorder_aid, memory_manager_aid, debug_aid),
kernel_(kernel),
kernel_info_(nullptr),
is_dynamic_shape_(false),
real_input_num_(0),
strategy_(strategy) {
strategy_(strategy),
modifiable_ref_input_indexes_(modifiable_ref_input_indexes),
modifiable_ref_output_indexes_(modifiable_ref_output_indexes) {
(void)device_contexts_.emplace_back(device_context);
}
~KernelActor() override = default;
@@ -74,6 +78,8 @@ class KernelActor : public DebugAwareActor {
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
const CNodePtr &kernel() const { return kernel_; }
const std::set<size_t> &modifiable_ref_input_indexes() const { return modifiable_ref_input_indexes_; }
const std::set<size_t> &modifiable_ref_output_indexes() const { return modifiable_ref_output_indexes_; }
protected:
void Run(OpContext<DeviceTensor> *const context) override;
@@ -86,6 +92,7 @@ class KernelActor : public DebugAwareActor {
// Fetch the device tensor for launch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
void FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context);
// A copy is needed when the device type or format of the real parameter is inconsistent with the formal parameter.
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);
@@ -94,6 +101,8 @@ class KernelActor : public DebugAwareActor {
void PreLaunchKernel(OpContext<DeviceTensor> *const context);
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
void PostLaunchKernel(OpContext<DeviceTensor> *const context);
// Refresh back the device tensor copy store entries whose copies were triggered during running.
void RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context);
// The size of output address may be changed in dynamic shape scenario, for example, the output shape of operator
// 'Unique' will change after PostExecute, the output address size should update.
@@ -116,8 +125,8 @@ class KernelActor : public DebugAwareActor {
std::vector<DeviceTensor *> input_device_tensors_;
std::vector<DeviceTensor *> output_device_tensors_;
std::vector<DeviceTensor *> workspace_device_tensors_;
// The received input device type may be different from the device context type in the control flow and host device
// scenarios, so it needs to be copied from the input device type to the device context type.
// In control flow scenarios, the received input device type and format may differ from those of the formal
// parameter, so the input data needs to be copied into the real data that the kernel launch uses.
std::vector<DeviceTensorPtr> copy_input_device_tensors_;
// The device tensors for memory alloc and free.
@@ -131,6 +140,10 @@ class KernelActor : public DebugAwareActor {
// The kernel launch info is fetched by the device tensors.
KernelLaunchInfo launch_info_;
// Record the modifiable ref indexes, used to refresh the ref data which are modified during running.
std::set<size_t> modifiable_ref_input_indexes_;
std::set<size_t> modifiable_ref_output_indexes_;
// Cache output data by output index to modify the output data effectively.
std::vector<std::vector<OpData<DeviceTensor> *>> output_data_by_output_index_;
};


@@ -622,14 +622,6 @@ KernelWithIndex FetchRealNodeByGetItem(const KernelWithIndex &node_with_index) {
return {get_item_src_node, *(indexes.begin())};
}
bool HasAbstractRef(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto &abs = node->abstract();
return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
}
bool IsCsrNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {


@@ -94,9 +94,6 @@ struct KernelGraphGroupInfo {
};
using KernelGraphGroupInfoPtr = std::shared_ptr<KernelGraphGroupInfo>;
// Check whether the parameter is a weight. In the control flow, weight is passed to the subgraph, and in the subgraph,
// it is determined whether it is a weight.
bool HasAbstractRef(const AnfNodePtr &node);
// Check whether the node is a csr node.
bool IsCsrNode(const AnfNodePtr &node);
// Get the front node corresponding to the backend node, if the front node is not a parameter node, return the


@@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
#include <memory>
#include <set>
#include "utils/hash_map.h"
#include "utils/ms_utils.h"
#include "runtime/device/device_address.h"
namespace mindspore {
namespace runtime {
using DeviceTensor = mindspore::device::DeviceAddress;
// The device tensor mainly includes the address ptr, size and reference count,
// and is the basic data structure for kernel launch and for transfers between actors.
// Some device tensors (such as ref real parameters) need to be refreshed during running,
// so they are kept in this store and can be fetched when an actor refreshes them by copy.
class DeviceTensorCopyStore {
public:
static DeviceTensorCopyStore &GetInstance() {
static DeviceTensorCopyStore instance;
return instance;
}
void Insert(DeviceTensor *const key, DeviceTensor *const value) {
MS_EXCEPTION_IF_NULL(key);
MS_EXCEPTION_IF_NULL(value);
(void)copy_device_tensors_[key].insert(value);
}
std::set<DeviceTensor *> Fetch(DeviceTensor *const key) const {
MS_EXCEPTION_IF_NULL(key);
const auto &iter = copy_device_tensors_.find(key);
if (iter != copy_device_tensors_.end()) {
return iter->second;
} else {
return {};
}
}
void Clear() { copy_device_tensors_.clear(); }
private:
DeviceTensorCopyStore() = default;
~DeviceTensorCopyStore() = default;
DISABLE_COPY_AND_ASSIGN(DeviceTensorCopyStore);
// The data storage of the device tensors which need to be refreshed back dynamically.
// Entries are created and removed dynamically during running.
// The key is the dest device tensor; the value is the set of source device tensors which provided its copy data.
mindspore::HashMap<DeviceTensor *, std::set<DeviceTensor *>> copy_device_tensors_;
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_


@@ -224,6 +224,9 @@ void GraphScheduler::Clear() {
void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
// Clear the data of the DeviceTensorCopyStore.
DeviceTensorCopyStore::GetInstance().Clear();
for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
MS_EXCEPTION_IF_NULL(super_kernel_actor);
super_kernel_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
@@ -735,8 +738,11 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
for (auto &kernel : execution_order) {
MS_EXCEPTION_IF_NULL(kernel);
if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph);
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
MS_EXCEPTION_IF_NULL(kernel_actor);
InsertActor(kernel_actor.get());
(void)kernel_actors.emplace_back(kernel_actor);