unified runtime add abstract actor and optimize code

This commit is contained in:
limingqi107 2021-08-12 20:35:43 +08:00
parent 0f74d2a704
commit 15e6ace23b
15 changed files with 375 additions and 391 deletions

View File

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/abstract_actor.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
bool AbstractActor::CheckRunningCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
}
if (input_controls_num_ != 0) {
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter == input_op_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_controls_num_) {
return false;
}
}
return true;
}
void AbstractActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
// The sequential num may be invalid, can't set the promise value of context.
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
return;
}
}
if (input_controls_num_ != 0) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();
// The sequential num may be invalid, can't set the promise value of context.
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
return;
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ABSTRACT_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ABSTRACT_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
const size_t kDeviceContextsNumOne = 1;
const size_t kDeviceContextsNumTwo = 2;
// The abstract common attributes of actors. The actor inheritance relationship: OpActor --> AbstractActor -->
// MemoryAwareActor --> DebugAwareActor --> KernelActor/DataSourceActor/CopyActor/LoopCountActor/OutputActor.
class AbstractActor : public OpActor<DeviceTensor> {
public:
explicit AbstractActor(const std::string &name, const AID *recorder_aid)
: OpActor(name),
recorder_aid_(recorder_aid),
input_datas_num_(0),
input_controls_num_(0),
running_dependent_msg_num_(0) {}
virtual ~AbstractActor() = default;
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
protected:
friend class GraphScheduler;
// Check whether satisfy the actor running condition.
bool CheckRunningCondition(OpContext<DeviceTensor> *const context) const;
// Erase input data and input controls when finish actor running.
void EraseInput(OpContext<DeviceTensor> *const context);
// The device interface.
std::vector<const DeviceContext *> device_contexts_;
// The id of recorder actor. Send message to it for recording info.
const AID *recorder_aid_;
// The output result arrows of graph output.
std::vector<DataArrowPtr> output_result_arrows_;
// The dependent device tensor stores, the dependent expression is pair<index, AnfNode>.
// Index is the input position, AnfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
// The dependent input actors.
std::vector<AID> input_data_arrow_aids_;
std::vector<AID> input_control_arrow_aids_;
// The dependent inputs number.
size_t input_datas_num_;
size_t input_controls_num_;
// The dependent messages number of actor running.
int running_dependent_msg_num_;
};
using AbstractActorPtr = std::shared_ptr<AbstractActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ABSTRACT_ACTOR_H_

View File

@ -24,6 +24,11 @@ namespace runtime {
const size_t kDeviceTensorNum = 1;
void CopyActor::Init() {
// Check device contexts number.
if (device_contexts_.size() != kDeviceContextsNumTwo) {
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
}
input_device_tensor_.resize(kDeviceTensorNum);
output_device_tensor_.resize(kDeviceTensorNum);
@ -43,7 +48,7 @@ void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Devi
auto &sequential_num = context->sequential_num_;
(void)input_op_datas_[sequential_num].emplace_back(input_data);
// When all the inputs are collected, then allocate memory and callback copy.
if (CheckCopyCondition(context)) {
if (CheckRunningCondition(context)) {
FetchDeviceTensor(context);
SendMemoryAllocReq(context);
}
@ -54,20 +59,20 @@ void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *
auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
// When all the inputs are collected, then allocate memory and callback copy.
if (CheckCopyCondition(context)) {
if (CheckRunningCondition(context)) {
FetchDeviceTensor(context);
SendMemoryAllocReq(context);
}
}
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_, output_device_context_,
context, GetAID());
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_, device_contexts_[1], context,
GetAID());
}
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_, input_device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_, output_device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_, device_contexts_[0], context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_, device_contexts_[1], context);
}
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
@ -96,50 +101,28 @@ void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
SendOutput(context);
}
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
}
if (input_controls_num_ != 0) {
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter == input_op_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_controls_num_) {
return false;
}
}
return true;
}
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_device_context_);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
if (device_tensor_store_key_.second != nullptr) {
input_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key_.second,
input_device_context_->GetDeviceAddressType());
if (device_tensor_store_keys_.size() > 0) {
input_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_keys_[0].second.get(),
device_contexts_[0]->GetDeviceAddressType());
if (input_device_tensor_[0] == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key_.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(input_device_context_->GetDeviceAddressType()));
GetAID().Name() +
" get device tensor store failed: " + device_tensor_store_keys_[0].second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
output_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key_.second,
output_device_context_->GetDeviceAddressType());
output_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_keys_[0].second.get(),
device_contexts_[1]->GetDeviceAddressType());
if (output_device_tensor_[0] == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key_.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(output_device_context_->GetDeviceAddressType()));
GetAID().Name() +
" get device tensor store failed: " + device_tensor_store_keys_[0].second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(device_contexts_[1]->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
} else {
@ -178,24 +161,5 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
}
}
}
void CopyActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
if (input_controls_num_ != 0) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -32,18 +32,12 @@ namespace runtime {
using mindspore::device::DeviceContext;
// The copy actor is used to receive the device tensors and control info to copy data between input device tensor and
// output device tensor. The processing flow is RunOpData/RunOpControl -> CheckCopyCondition -> SendMemoryAllocReq
// output device tensor. The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
// -> OnMemoryAllocFinish -> Copy -> SendMemoryFreeReq -> SendOutput.
class CopyActor : public MemoryAwareActor {
public:
CopyActor(const std::string &name, const AID &memory_manager_aid)
: MemoryAwareActor(name),
memory_manager_aid_(memory_manager_aid),
input_datas_num_(0),
input_controls_num_(0),
input_device_context_(nullptr),
output_device_context_(nullptr),
output_(nullptr) {}
: MemoryAwareActor(name, nullptr, memory_manager_aid), output_(nullptr) {}
~CopyActor() override = default;
void Init() override;
@ -62,34 +56,15 @@ class CopyActor : public MemoryAwareActor {
private:
friend class GraphScheduler;
// Check whether satisfy the condition for copy.
bool CheckCopyCondition(OpContext<DeviceTensor> *const context) const;
// Fetch the device tensor for copy.
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);
// Send output data and output controls when finish copy.
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Erase input data and input controls when finish copy.
void EraseInput(OpContext<DeviceTensor> *const context);
// The id of memory manager actor. Send message to it for alloc and free memory during the copy.
const AID memory_manager_aid_;
// The dependent input data number.
size_t input_datas_num_;
// The dependent input controls number.
size_t input_controls_num_;
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::pair<size_t, AnfNode *> device_tensor_store_key_;
// The device interface for copy.
const DeviceContext *input_device_context_;
const DeviceContext *output_device_context_;
// The input device tensor is saved from the input data or fetched by device_tensor_store_key_.
// The input device tensor is saved from the input data or fetched by device_tensor_store_keys_.
std::vector<DeviceTensor *> input_device_tensor_;
// The output device tensor is saved from the output or fetched by device_tensor_store_key_.
// The output device tensor is saved from the output or fetched by device_tensor_store_keys_.
std::vector<DeviceTensor *> output_device_tensor_;
// The output_data_ corresponds to the output_data_arrows_ one by one.

View File

@ -27,6 +27,11 @@
namespace mindspore {
namespace runtime {
void DataSourceActor::Init() {
// Check device contexts number.
if (device_contexts_.size() < kDeviceContextsNumOne) {
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
}
// Init output data.
for (auto &data_arrow : output_data_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
@ -98,6 +103,11 @@ void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
}
void DeviceQueueDataSourceActor::Init() {
// Check device contexts number.
if (device_contexts_.size() != kDeviceContextsNumOne) {
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
}
// Init output data.
for (auto &data_arrow : output_data_arrows_) {
MS_EXCEPTION_IF_NULL(data_arrow);
@ -126,17 +136,18 @@ void DeviceQueueDataSourceActor::FillDataBuffer() {
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
auto &device_tensors = buffers_.back();
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_context_, context, GetAID());
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
GetAID());
}
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
auto &device_tensors = buffers_.front();
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
}
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}
@ -151,8 +162,8 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
// Copy data from device queue by data kernel launching.
try {
auto ret = device_context_->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_);
if (!ret) {
std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
@ -178,7 +189,7 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
}
void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_context_, context, &GetAID());
Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_contexts_[0], context, &GetAID());
}
void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
@ -197,7 +208,7 @@ void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const conte
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
if (recorder_aid_ != nullptr) {
Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
device_context_, context);
device_contexts_[0], context);
}
}

View File

@ -41,13 +41,9 @@ using mindspore::kernel::KernelLaunchInfo;
// -> OnMemoryAllocFinish -> SendMemoryFreeReq -> SendOutput.
class DataSourceActor : public DebugAwareActor {
public:
DataSourceActor(const std::string &name, size_t buffer_capacity, const AID memory_manager_aid, const AID *debug_aid,
DataSourceActor(const std::string &name, size_t buffer_capacity, const AID &memory_manager_aid, const AID *debug_aid,
const AID *recorder_aid)
: DebugAwareActor(name),
buffer_capacity_(buffer_capacity),
memory_manager_aid_(memory_manager_aid),
debug_aid_(debug_aid),
recorder_aid_(recorder_aid) {}
: DebugAwareActor(name, recorder_aid, memory_manager_aid, debug_aid), buffer_capacity_(buffer_capacity) {}
virtual ~DataSourceActor() = default;
void Init() override;
@ -76,20 +72,10 @@ class DataSourceActor : public DebugAwareActor {
// Send output to downstream actors to trigger computing after fetching data finished.
void SendOutput(OpContext<DeviceTensor> *const context);
// The output result arrows of graph output.
std::vector<DataArrowPtr> output_result_arrows_;
// The buffers store the device tensors.
std::queue<std::vector<DeviceTensor *>> buffers_;
size_t buffer_capacity_;
// The id of memory manager actor. Send message to it for alloc and free memory during the data processing.
const AID memory_manager_aid_;
// The id of debug actor. Send message to it for debug after the kernel launch.
const AID *debug_aid_;
// The id of recorder actor. Send message to it for recording kernel info after the kernel launch.
const AID *recorder_aid_;
// The output_data_ corresponds to the output_data_arrows_ one by one.
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;
};
@ -97,10 +83,11 @@ class DataSourceActor : public DebugAwareActor {
// The class represents that the data source is device queue.
class DeviceQueueDataSourceActor : public DataSourceActor {
public:
DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
const AID memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
: DataSourceActor(name, buffer_capacity, memory_manager_aid, debug_aid, recorder_aid),
device_context_(device_context) {}
DeviceQueueDataSourceActor(const std::string &name, size_t buffer_capacity, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
: DataSourceActor(name, buffer_capacity, memory_manager_aid, debug_aid, recorder_aid) {
(void)device_contexts_.emplace_back(device_context);
}
~DeviceQueueDataSourceActor() override = default;
void Init() override;
@ -126,8 +113,6 @@ class DeviceQueueDataSourceActor : public DataSourceActor {
// The kernel launch info is fetched by the device tensors.
KernelLaunchInfo launch_info_;
const DeviceContext *device_context_;
};
// The class represents that the data source is host queue.
@ -157,8 +142,6 @@ class HostQueueDataSourceActor : public DataSourceActor {
HostTensorQueuePtr host_queue_;
// Input data nodes fetch data from host queue.
std::vector<AnfNodePtr> data_nodes_;
// The device contexts corresponding to the data nodes.
std::vector<const DeviceContext *> device_contexts_;
// The location of the data node in the data source actor.
std::unordered_map<AnfNodePtr, size_t> data_node_position_map_;

View File

@ -25,10 +25,17 @@ namespace runtime {
// The actor represents a set of common debug related operations of actor.
class DebugAwareActor : public MemoryAwareActor {
public:
explicit DebugAwareActor(const std::string &name) : MemoryAwareActor(name) {}
explicit DebugAwareActor(const std::string &name, const AID *recorder_aid, const AID &memory_manager_aid,
const AID *debug_aid)
: MemoryAwareActor(name, recorder_aid, memory_manager_aid), debug_aid_(debug_aid) {}
virtual ~DebugAwareActor() = default;
virtual void SendDebugReq(OpContext<DeviceTensor> *const context) {}
virtual void OnDebugFinish(OpContext<DeviceTensor> *const context) {}
protected:
// The id of debug actor. Send message to it for debug.
const AID *debug_aid_;
};
} // namespace runtime
} // namespace mindspore

View File

@ -25,6 +25,11 @@
namespace mindspore {
namespace runtime {
void KernelActor::Init() {
// Check device contexts number.
if (device_contexts_.size() != kDeviceContextsNumOne) {
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
}
// Set the number of actor running dependent messages.
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
@ -84,10 +89,10 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
if (CheckRunningCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_context_->UpdateDynamicShape(kernel_);
device_contexts_[0]->UpdateDynamicShape(kernel_);
}
FetchInputDeviceTensor(context);
@ -105,10 +110,10 @@ void KernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor>
auto &sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
if (CheckRunningCondition(context)) {
// Infer kernel shape and update abstract info for dynamic shape kernel.
if (is_dynamic_shape_) {
device_context_->UpdateDynamicShape(kernel_);
device_contexts_[0]->UpdateDynamicShape(kernel_);
}
FetchInputDeviceTensor(context);
@ -130,7 +135,7 @@ void KernelActor::RunOpControlWithInputTensor(AID *const input_control, OpContex
PushInputDeviceTensor(input_tensors);
// When all the inputs are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
if (CheckRunningCondition(context)) {
FetchOutputDeviceTensor();
if (memory_alloc_list_.size() > 0) {
SendMemoryAllocReq(context);
@ -181,30 +186,30 @@ void FreeMemory(const std::vector<DeviceTensor *> &free_list, const DeviceContex
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
running_dependent_msg_num_ = 1;
if (strategy_ == GraphExecutionStrategy::kPipeline) {
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_, device_context_, context,
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_, device_contexts_[0], context,
GetAID());
} else {
AllocateMemory(memory_alloc_list_, device_context_);
AllocateMemory(memory_alloc_list_, device_contexts_[0]);
}
}
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
if (strategy_ == GraphExecutionStrategy::kPipeline) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_contexts_[0], context);
} else {
FreeMemory(memory_free_list_, device_context_);
FreeMemory(memory_free_list_, device_contexts_[0]);
}
}
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(kernel_);
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
PreLaunchKernel(context);
try {
auto ret = device_context_->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_, is_dynamic_shape_);
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
launch_info_.outputs_, is_dynamic_shape_);
if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
@ -226,7 +231,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
void KernelActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
running_dependent_msg_num_ = 1;
Async(*debug_aid_, &DebugActor::Debug, kernel_, &launch_info_, device_context_, context, &GetAID());
Async(*debug_aid_, &DebugActor::Debug, kernel_, &launch_info_, device_contexts_[0], context, &GetAID());
}
void KernelActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
@ -234,30 +239,6 @@ void KernelActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
PostLaunchKernel(context);
}
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
}
if (input_controls_num_ != 0) {
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter == input_op_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_controls_num_) {
return false;
}
}
return true;
}
void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
if (input_tensors->size() != real_input_num_) {
@ -279,24 +260,25 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data);
if ((input_data->data_ == nullptr) || (input_data->data_->DeviceType() == device_context_->GetDeviceAddressType())) {
if ((input_data->data_ == nullptr) ||
(input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType())) {
return;
}
MS_LOG(DEBUG) << "Copy from device type: " << input_data->data_->DeviceType()
<< " to device type: " << device_context_->GetDeviceAddressType() << " in " << GetAID().Name();
<< " to device type: " << device_contexts_[0]->GetDeviceAddressType() << " in " << GetAID().Name();
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
copy_input_device_tensors_[input_data->index_] = device_context_->CreateDeviceAddress(
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
nullptr, input_data->data_->GetSize(), input_data->data_->format(), input_data->data_->type_id());
}
// Dynamic shape need update size.
copy_input_device_tensors_[input_data->index_]->SetSize(input_data->data_->GetSize());
if (copy_input_device_tensors_[input_data->index_]->GetPtr() == nullptr) {
if (!device_context_->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
copy_input_device_tensors_[input_data->index_]->GetSize())) {
if (!device_contexts_[0]->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
copy_input_device_tensors_[input_data->index_]->GetSize())) {
std::string error_info =
"Device(id:" + std::to_string(device_context_->device_context_key().device_id_) +
"Device(id:" + std::to_string(device_contexts_[0]->device_context_key().device_id_) +
") memory isn't enough and alloc failed, actor name: " + GetAID().Name() +
", alloc size: " + std::to_string(copy_input_device_tensors_[input_data->index_]->GetSize());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
@ -315,7 +297,7 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
@ -330,12 +312,12 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context)
}
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor =
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
device_contexts_[0]->GetDeviceAddressType());
if (device_tensor == nullptr) {
std::string error_info =
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
}
if (input_device_tensors_[device_tensor_store_key.first] != device_tensor) {
@ -439,8 +421,8 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
// 4.Send recorder info.
if (recorder_aid_ != nullptr) {
Async(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &launch_info_, device_context_,
context);
Async(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &launch_info_,
device_contexts_[0], context);
}
// No output.
@ -449,28 +431,5 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
}
void KernelActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input data failed: " + GetAID().Name();
// The sequential num may be invalid, can't set the promise value of context.
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
return;
}
}
if (input_controls_num_ != 0) {
auto ret = input_op_controls_.erase(context->sequential_num_);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();
// The sequential num may be invalid, can't set the promise value of context.
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
return;
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -39,30 +39,24 @@ using mindspore::kernel::KernelLaunchInfo;
using mindspore::tensor::TensorPtr;
// The kernel actor is used to receive the device tensors and control info to luanch kernel.
// The processing flow is RunOpData/RunOpControl -> CheckLaunchCondition -> SendMemoryAllocReq
// The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
// -> OnMemoryAllocFinish -> LaunchKernel -> SendMemoryFreeReq -> SendOutput.
class KernelActor : public DebugAwareActor {
public:
KernelActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy)
: DebugAwareActor(name),
: DebugAwareActor(name, recorder_aid, memory_manager_aid, debug_aid),
kernel_(kernel),
kernel_info_(nullptr),
is_dynamic_shape_(false),
device_context_(device_context),
memory_manager_aid_(memory_manager_aid),
debug_aid_(debug_aid),
recorder_aid_(recorder_aid),
input_datas_num_(0),
input_controls_num_(0),
real_input_num_(0),
running_dependent_msg_num_(1),
strategy_(strategy) {}
strategy_(strategy) {
(void)device_contexts_.emplace_back(device_context);
}
~KernelActor() override = default;
void Init() override;
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
// The kernel actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
@ -86,8 +80,6 @@ class KernelActor : public DebugAwareActor {
private:
friend class GraphScheduler;
// Check whether satisfy the condition for launch.
bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
// Fetch the device tensor for launch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
void FetchOutputDeviceTensor();
@ -102,45 +94,20 @@ class KernelActor : public DebugAwareActor {
// Send output data and output controls when finish kernel launch.
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Erase input data and input controls when finish kernel launch.
void EraseInput(OpContext<DeviceTensor> *const context);
// The info of kernel.
CNodePtr kernel_;
KernelInfo *kernel_info_;
bool is_dynamic_shape_;
// The device interface of kernel launch.
const DeviceContext *device_context_;
// The id of memory manager actor. Send message to it for alloc and free memory during the kernel launch.
const AID memory_manager_aid_;
// The id of debug actor. Send message to it for debug after the kernel launch.
const AID *debug_aid_;
// The id of recorder actor. Send message to it for recording kernel info after the kernel launch.
const AID *recorder_aid_;
// The dependent input data number.
size_t input_datas_num_;
// The dependent input controls number.
size_t input_controls_num_;
// The real input number of kernel launch.
size_t real_input_num_;
// The dependent messages number of actor running.
int running_dependent_msg_num_;
// The execution strategy of kernel actor.
// In pipeline mode, kernel actor executes asynchronously.
// In step mode, kernel actor executes synchronously.
GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
// The dependent input actors.
std::vector<AID> input_data_arrow_aids_;
std::vector<AID> input_control_arrow_aids_;
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
// The device tensors for launch.
std::vector<DeviceTensor *> input_device_tensors_;
std::vector<DeviceTensor *> output_device_tensors_;
@ -160,9 +127,6 @@ class KernelActor : public DebugAwareActor {
// The kernel launch info is fetched by the device tensors.
KernelLaunchInfo launch_info_;
// The output result arrows of graph output.
std::vector<DataArrowPtr> output_result_arrows_;
// Cache unique output data by output index to modify the output data effectively.
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
// The output_data_ corresponds to the output_data_arrows_ one by one.

View File

@ -86,7 +86,7 @@ void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTens
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
(void)input_op_controls_[sequential_num].emplace_back(input_control);
if (CheckLoopCountIncreaseCondition(context)) {
if (CheckRunningCondition(context)) {
IncreaseLoopCount(context);
}
}
@ -102,12 +102,7 @@ void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
auto ret = input_op_controls_.erase(sequential_num);
if (ret == 0) {
std::string error_info = "Erase input controls failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
EraseInput(context);
total_running_count_++;
current_count_++;
@ -165,12 +160,5 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context)
Async(kernel_aid, &KernelActor::RunOpControl, source_aid, context);
}
}
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
return input_op_controls_[sequential_num].size() == input_controls_num_;
}
} // namespace runtime
} // namespace mindspore

View File

@ -34,16 +34,12 @@ namespace runtime {
// and decide whether to loop execution by loop count.
class LoopCountActor : public DebugAwareActor {
public:
LoopCountActor(std::string name, size_t loop_count, const AID memory_manager_aid, const AID *debug_aid,
LoopCountActor(const std::string &name, size_t loop_count, const AID &memory_manager_aid, const AID *debug_aid,
const AID *recorder_aid)
: DebugAwareActor(name),
: DebugAwareActor(name, recorder_aid, memory_manager_aid, debug_aid),
loop_count_(loop_count),
current_count_(0),
total_running_count_(0),
input_controls_num_(0),
memory_manager_aid_(memory_manager_aid),
debug_aid_(debug_aid),
recorder_aid_(recorder_aid) {}
total_running_count_(0) {}
~LoopCountActor() override = default;
@ -68,30 +64,17 @@ class LoopCountActor : public DebugAwareActor {
void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
void SendOutput(OpContext<DeviceTensor> *const context);
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context);
// The loop count is constant, the current count is increased after each step running finished.
size_t loop_count_;
size_t current_count_;
// The total running count represents the toal step running count.
size_t total_running_count_;
// The dependent input controls number.
// In the multi-branch output scenario of the control flow, the control of each branch needs to be recorded
// separately with the branch id as the key. When the output has only one branch, the branch id is 0.
size_t input_controls_num_;
// The output controls contain the data source actors and the no input kernel actors and output actor.
std::vector<AID> data_source_aids_;
std::vector<AID> no_input_kernel_aids_;
AID output_aid_;
// The id of memory manager actor. Send message to it for alloc continuous memory before next step running.
const AID memory_manager_aid_;
// The id of debug actor. Send message to it for debug before loop count actor exits.
const AID *debug_aid_;
// The id of recorder actor. Send message to it for clearing recorder info before loop count actor exits.
const AID *recorder_aid_;
// The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair
// expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need
// continuous memory.
@ -100,7 +83,6 @@ class LoopCountActor : public DebugAwareActor {
std::vector<std::vector<DeviceTensorPtr>> continuous_memory_alloc_list_list_;
std::vector<std::vector<size_t>> size_list_list_;
std::vector<size_t> total_size_list_;
std::vector<const DeviceContext *> device_contexts_;
};
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;

View File

@ -19,21 +19,27 @@
#include <utility>
#include <string>
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/actor/abstract_actor.h"
#include "runtime/framework/device_tensor_store.h"
namespace mindspore {
namespace runtime {
// The actor represents a set of common memory related operations of actor.
class MemoryAwareActor : public OpActor<DeviceTensor> {
class MemoryAwareActor : public AbstractActor {
public:
explicit MemoryAwareActor(std::string name) : OpActor(name) {}
explicit MemoryAwareActor(const std::string &name, const AID *recorder_aid, const AID &memory_manager_aid)
: AbstractActor(name, recorder_aid), memory_manager_aid_(memory_manager_aid) {}
virtual ~MemoryAwareActor() = default;
virtual void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {}
virtual void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {}
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {}
protected:
friend class GraphScheduler;
// The id of memory manager actor. Send message to it for alloc and free memory.
const AID memory_manager_aid_;
};
} // namespace runtime
} // namespace mindspore

View File

@ -26,6 +26,7 @@
#include "runtime/framework/control_node_parser.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/abstract_actor.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/tensor.h"
@ -37,16 +38,15 @@ using mindspore::session::KernelWithIndex;
using mindspore::tensor::TensorPtr;
// The output actor is used to receive the output result of actor which represents the graph output.
class OutputActor : public OpActor<DeviceTensor> {
class OutputActor : public AbstractActor {
public:
OutputActor(std::string name, size_t loop_count, size_t outputs_num, bool need_loop_count)
: OpActor(name),
: AbstractActor(name, nullptr),
loop_count_(loop_count),
current_count_(0),
outputs_num_(outputs_num),
current_outputs_num_(0),
need_loop_count_(need_loop_count),
running_dependent_msg_num_(1) {
need_loop_count_(need_loop_count) {
outputs_.resize(outputs_num);
output_nodes_.resize(outputs_num);
device_contexts_.resize(outputs_num);
@ -54,7 +54,6 @@ class OutputActor : public OpActor<DeviceTensor> {
~OutputActor() override = default;
void Init() override;
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
// The output actor collects loop count when receive the input control of loop count actor.
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context);
@ -80,15 +79,9 @@ class OutputActor : public OpActor<DeviceTensor> {
// The outputs.
std::vector<TensorPtr> outputs_;
std::vector<KernelWithIndex> output_nodes_;
std::vector<const DeviceContext *> device_contexts_;
size_t outputs_num_;
size_t current_outputs_num_;
bool need_loop_count_;
// The dependent messages number of actor running.
int running_dependent_msg_num_;
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
};
using OutputActorPtr = std::shared_ptr<OutputActor>;

View File

@ -1361,7 +1361,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, SizeToInt(kernel_with_index.second));
if (HasAbstractRef(real_front_node_with_index.first)) {
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
real_front_node_with_index.first.get());
real_front_node_with_index.first);
return;
}
@ -1388,7 +1388,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
auto actor_name = func_graph->ToString();
const auto &from_actor = dynamic_cast<GatherActor *>(FetchActor(actor_name));
if (HasAbstractRef(from_kernel)) {
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node.get());
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node);
return;
}
LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
@ -1409,8 +1409,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
to_kernel_with_input_idx);
} else if (IsPersistentDeviceTensor(from_kernel)) {
const auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
device_tensor_store_key.get());
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
} else {
// May exist the from kernel that no need link in the pynative mode.
MS_LOG(DEBUG) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
@ -1437,7 +1436,7 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
return;
}
if (IsPersistentDeviceTensor(front_output_node)) {
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get());
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node);
return;
}
@ -1486,7 +1485,7 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *c
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
if (IsNeedInsertCopyActor(from_actor->device_contexts_[0], to_actor->device_contexts_[0])) {
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
auto to_aid = to_actor->GetAID();
@ -1514,7 +1513,7 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *const
// Get the position of from kernel in the data source actor.
auto position = from_actor->FetchDataNodePosition(from_kernel);
if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_context_)) {
if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
auto to_aid = to_actor->GetAID();
@ -1554,7 +1553,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
if (IsNeedInsertCopyActor(from_actor->device_contexts_[0], to_actor->device_contexts_[0])) {
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
auto to_aid = to_actor->GetAID();
@ -1575,7 +1574,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from
MS_EXCEPTION_IF_NULL(to_actor);
auto from_kernel = from_kernel_with_output_idx.first;
MS_EXCEPTION_IF_NULL(from_kernel);
auto to_device_context = to_actor->device_context_;
auto to_device_context = to_actor->device_contexts_[0];
MS_EXCEPTION_IF_NULL(to_device_context);
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
@ -1599,12 +1598,12 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from
if (IsDeviceQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
MS_EXCEPTION_IF_NULL(real_from_actor);
from_device_context = real_from_actor->device_context_;
from_device_context = real_from_actor->device_contexts_[0];
(void)real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
} else if (IsKernelActor(from_kernel)) {
auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
MS_EXCEPTION_IF_NULL(real_from_actor);
from_device_context = real_from_actor->device_context_;
from_device_context = real_from_actor->device_contexts_[0];
(void)real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
} else if (IsHostQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
@ -1629,8 +1628,8 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from
copy_actor->output_ = to_device_context->CreateDeviceAddress(
nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
MS_EXCEPTION_IF_NULL(from_device_context);
copy_actor->input_device_context_ = from_device_context;
copy_actor->output_device_context_ = to_device_context;
(void)copy_actor->device_contexts_.emplace_back(from_device_context);
(void)copy_actor->device_contexts_.emplace_back(to_device_context);
// Update the reference count of device tensor.
UpdateRefCount(from_device_tensor.get());
@ -2060,7 +2059,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
for (auto &kernel_actor : auto_monad_actors) {
MS_EXCEPTION_IF_NULL(kernel_actor);
for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
continue;
}
@ -2077,9 +2076,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
InsertActor(copy_actor.get());
// Set the member of the copy actor.
copy_actor->device_tensor_store_key_ = std::pair<size_t, AnfNode *>(0, device_tensor_store_key.second);
auto input_device_context = kernel_actor->device_context_;
copy_actor->input_device_context_ = input_device_context;
(void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
auto input_device_context = kernel_actor->device_contexts_[0];
(void)copy_actor->device_contexts_.emplace_back(input_device_context);
auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
? device_tensors[1]
: device_tensors[0];
@ -2088,7 +2087,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
{device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
MS_EXCEPTION_IF_NULL(another_device_context);
copy_actor->output_device_context_ = another_device_context;
(void)copy_actor->device_contexts_.emplace_back(another_device_context);
MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
<< "has control arrows number:" << kernel_actor->output_control_arrows_.size();
@ -2662,7 +2661,7 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
const size_t kCopyActorInputDataNum = 1;
auto input_data_num = copy_actor->input_datas_num_;
size_t device_tensor_store_num = (copy_actor->device_tensor_store_key_.second == nullptr) ? 0 : 1;
size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
<< " is wrong, input data num: " << input_data_num
@ -2838,78 +2837,131 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
ofs << "[Device tensor stores]\n";
DumpDeviceTensorStore(graph_compiler_info, ofs);
ofs << "\n\n[Data source actors]\n";
ofs << "\n\n[Data source actors:" << actor_set->data_source_actors_.size() << "]\n";
for (const auto &data_source_actor : actor_set->data_source_actors_) {
DumpDSActor(data_source_actor.get(), ofs);
}
ofs << "\n\n[Kernel actors]\n";
ofs << "\n\n[Kernel actors:" << actor_set->kernel_actors_.size() << "]\n";
for (const auto &kernel_actor : actor_set->kernel_actors_) {
DumpKernelActor(kernel_actor.get(), ofs);
}
ofs << "\n\n[No input kernel actors]\n";
ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n";
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
DumpKernelActor(no_input_kernel_actor.get(), ofs);
}
ofs << "\n\n[Copy actors]\n";
ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n";
for (const auto &copy_actor : actor_set->copy_actors_) {
DumpCopyActor(copy_actor.get(), ofs);
}
ofs << "\n\n[Gather actors]\n";
ofs << "\n\n[Gather actors:" << actor_set->gather_actors_.size() << "]\n";
for (const auto &gather_actor : actor_set->gather_actors_) {
DumpGatherActor(gather_actor.get(), ofs);
}
ofs << "\n\n[Switch actors]\n";
ofs << "\n\n[Switch actors:" << actor_set->switch_actors_.size() << "]\n";
for (const auto &switch_actor : actor_set->switch_actors_) {
DumpSwitchActor(switch_actor.get(), ofs);
}
ofs << "\n\n[Loop count actor]\n";
const auto &loop_count_actor = actor_set->loop_count_actor_;
ofs << "\n\n[Loop count actor:" << (loop_count_actor != nullptr ? 1 : 0) << "]\n";
if (loop_count_actor != nullptr) {
DumpLoopCountActor(loop_count_actor.get(), ofs);
}
ofs << "\n\n[Output actor]\n";
const auto &output_actor = actor_set->output_actor_;
ofs << "\n\n[Output actor:" << (output_actor != nullptr ? 1 : 0) << "]\n";
if (output_actor != nullptr) {
DumpOutputActor(output_actor.get(), ofs);
}
}
void GraphScheduler::DumpBaseActor(const OpActor<DeviceTensor> *actor, std::ofstream &ofs) const {
void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size()
<< "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size()
<< "\tinput_data_arrow_actors_num:" << actor->input_datas_num_
<< "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n";
ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size()
<< "\toutput_control_arrows_num:" << actor->output_control_arrows_.size()
<< "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n";
if (actor->device_contexts_.size() > 0) {
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
for (const auto &device_context : actor->device_contexts_) {
if (device_context == nullptr) {
ofs << "\t\t\tdevice_context:" << device_context << "\n";
continue;
}
ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
}
}
if (actor->device_tensor_store_keys_.size() > 0) {
ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
<< "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
}
}
if (actor->input_data_arrow_aids_.size() > 0) {
ofs << "\t\tinput_data_arrow_actors:" << actor->input_data_arrow_aids_.size() << "\n ";
for (const auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
ofs << "\t\t\tfrom_actor_name:" << input_data_arrow_aid.Name() << "\n";
}
}
if (actor->input_control_arrow_aids_.size() > 0) {
ofs << "\t\tinput_control_arrow_actors:" << actor->input_control_arrow_aids_.size() << "\n ";
for (const auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
ofs << "\t\t\tfrom_actor_name:" << input_control_arrow_aid.Name() << "\n";
}
}
const auto &output_data_arrows = actor->output_data_arrows();
ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
for (const auto &data_arrow : output_data_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
if (output_data_arrows.size() > 0) {
ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
for (const auto &data_arrow : output_data_arrows) {
MS_EXCEPTION_IF_NULL(data_arrow);
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
<< "\n";
}
}
const auto &output_control_arrows = actor->output_control_arrows();
ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
for (const auto &aid : output_control_arrows) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
if (output_control_arrows.size() > 0) {
ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
for (const auto &aid : output_control_arrows) {
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
}
}
if (actor->output_result_arrows_.size() > 0) {
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
for (const auto &result_arrow : actor->output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
}
}
void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
const auto &actor_name = actor->GetAID().Name();
ofs << "\tactor_name:" << actor_name << "\n";
if (actor_name.find("_DeviceDSActor") != string::npos) {
// Dump the member info of device queue data source actor.
const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor->device_context_);
ofs << "\tactor_name:" << actor_name
<< "\tdevice_context:" << device_queue_ds_actor->device_context_->device_context_key().ToString() << "\n";
const auto &data_kernel = device_queue_ds_actor->data_kernel_;
MS_EXCEPTION_IF_NULL(data_kernel);
ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
@ -2923,7 +2975,6 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
}
} else if (actor_name.find("_HostDSActor") != string::npos) {
// Dump the member info of host queue data source actor.
ofs << "\tactor_name:" << actor_name << "\n";
const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
@ -2933,27 +2984,18 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
MS_EXCEPTION_IF_NULL(device_tensor);
ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
<< "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
<< "\toriginal_ref_count:" << device_tensor->original_ref_count()
<< "\tdevice_context:" << host_queue_ds_actor->device_contexts_[i]->device_context_key().ToString() << "\n";
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n";
}
}
DumpBaseActor(actor, ofs);
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
for (const auto &result_arrow : actor->output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
DumpAbstractActor(actor, ofs);
ofs << "\n";
}
void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
<< "\tinput_controls_num:" << actor->input_controls_num_ << "\n";
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ << "\n";
DumpAbstractActor(actor, ofs);
ofs << "\t\toutput_control_arrows:" << (actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() + 1)
<< "\n ";
@ -2975,16 +3017,12 @@ void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstre
void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
MS_EXCEPTION_IF_NULL(actor->device_context_);
ofs << "\tactor_name:" << actor->GetAID().Name()
<< "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
<< "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_
<< "\n";
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
const auto &kernel = actor->kernel_;
MS_EXCEPTION_IF_NULL(kernel);
ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinput_number:" << AnfAlgo::GetInputTensorNum(kernel)
<< "\toutput_number:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinputs_num:" << AnfAlgo::GetInputTensorNum(kernel)
<< "\toutputs_num:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
@ -2992,22 +3030,7 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
}
ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n ";
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
<< "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
}
DumpBaseActor(actor, ofs);
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
for (const auto &result_arrow : actor->output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
}
DumpAbstractActor(actor, ofs);
ofs << "\n";
}
@ -3015,33 +3038,12 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
<< "\toutputs_num:" << actor->outputs_num_ << "\n";
ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
ofs << "\t\t\toutput_node_position:" << device_tensor_store_key.first
<< "\toutput_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
}
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
for (const auto &device_context : actor->device_contexts_) {
if (device_context == nullptr) {
ofs << "\t\t\tdevice_context:" << device_context << "\n";
continue;
}
ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
}
DumpAbstractActor(actor, ofs);
}
void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
MS_EXCEPTION_IF_NULL(actor);
MS_EXCEPTION_IF_NULL(actor->input_device_context_);
MS_EXCEPTION_IF_NULL(actor->output_device_context_);
ofs << "\tactor_name:" << actor->GetAID().Name()
<< "\tinput_device_context:" << actor->input_device_context_->device_context_key().ToString()
<< "\toutput_device_context:" << actor->output_device_context_->device_context_key().ToString()
<< "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_
<< "\n";
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
auto device_tensor = actor->output_;
if (device_tensor != nullptr) {
@ -3049,13 +3051,7 @@ void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) c
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
}
if (actor->device_tensor_store_key_.second != nullptr) {
ofs << "\t\tdevice_tensor_stores:" << 1 << "\n ";
ofs << "\t\t\tto_input_index:" << actor->device_tensor_store_key_.first
<< "\tfrom_node_name:" << actor->device_tensor_store_key_.second->fullname_with_scope() << "\n";
}
DumpBaseActor(actor, ofs);
DumpAbstractActor(actor, ofs);
ofs << "\n";
}
@ -3144,6 +3140,7 @@ void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &of
for (const auto &control_arrow : actor->output_control_arrows_) {
ofs << "\t\t\tto_actor_name:" << control_arrow;
}
ofs << "\n";
}
void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
@ -3191,6 +3188,7 @@ void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &of
ofs << "\t\t\t\t from index:" << arrow << '\n';
}
}
ofs << "\n";
}
} // namespace runtime
} // namespace mindspore

View File

@ -302,7 +302,7 @@ class GraphScheduler {
// Display the actor information of corresponding kernel graph.
void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
void DumpBaseActor(const OpActor<DeviceTensor> *actor, std::ofstream &ofs) const;
void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const;
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;