forked from mindspore-Ecosystem/mindspore
unified runtime add abstract actor and optimize code
This commit is contained in:
parent
0f74d2a704
commit
15e6ace23b
|
@ -0,0 +1,69 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/framework/actor/abstract_actor.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
bool AbstractActor::CheckRunningCondition(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter == input_op_datas_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_datas_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
|
||||
if (control_iter == input_op_controls_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (control_iter->second.size() != input_controls_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AbstractActor::EraseInput(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input data failed: " + GetAID().Name();
|
||||
// The sequential num may be invalid, can't set the promise value of context.
|
||||
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
auto ret = input_op_controls_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input controls failed: " + GetAID().Name();
|
||||
// The sequential num may be invalid, can't set the promise value of context.
|
||||
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ABSTRACT_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ABSTRACT_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
|
||||
const size_t kDeviceContextsNumOne = 1;
|
||||
const size_t kDeviceContextsNumTwo = 2;
|
||||
|
||||
// The abstract common attributes of actors. The actor inheritance relationship: OpActor --> AbstractActor -->
|
||||
// MemoryAwareActor --> DebugAwareActor --> KernelActor/DataSourceActor/CopyActor/LoopCountActor/OutputActor.
|
||||
class AbstractActor : public OpActor<DeviceTensor> {
|
||||
public:
|
||||
explicit AbstractActor(const std::string &name, const AID *recorder_aid)
|
||||
: OpActor(name),
|
||||
recorder_aid_(recorder_aid),
|
||||
input_datas_num_(0),
|
||||
input_controls_num_(0),
|
||||
running_dependent_msg_num_(0) {}
|
||||
virtual ~AbstractActor() = default;
|
||||
|
||||
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
|
||||
|
||||
protected:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Check whether satisfy the actor running condition.
|
||||
bool CheckRunningCondition(OpContext<DeviceTensor> *const context) const;
|
||||
// Erase input data and input controls when finish actor running.
|
||||
void EraseInput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The device interface.
|
||||
std::vector<const DeviceContext *> device_contexts_;
|
||||
|
||||
// The id of recorder actor. Send message to it for recording info.
|
||||
const AID *recorder_aid_;
|
||||
|
||||
// The output result arrows of graph output.
|
||||
std::vector<DataArrowPtr> output_result_arrows_;
|
||||
|
||||
// The dependent device tensor stores, the dependent expression is pair<index, AnfNode>.
|
||||
// Index is the input position, AnfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
|
||||
|
||||
// The dependent input actors.
|
||||
std::vector<AID> input_data_arrow_aids_;
|
||||
std::vector<AID> input_control_arrow_aids_;
|
||||
// The dependent inputs number.
|
||||
size_t input_datas_num_;
|
||||
size_t input_controls_num_;
|
||||
|
||||
// The dependent messages number of actor running.
|
||||
int running_dependent_msg_num_;
|
||||
};
|
||||
|
||||
using AbstractActorPtr = std::shared_ptr<AbstractActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ABSTRACT_ACTOR_H_
|
|
@ -24,6 +24,11 @@ namespace runtime {
|
|||
const size_t kDeviceTensorNum = 1;
|
||||
|
||||
void CopyActor::Init() {
|
||||
// Check device contexts number.
|
||||
if (device_contexts_.size() != kDeviceContextsNumTwo) {
|
||||
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
|
||||
}
|
||||
|
||||
input_device_tensor_.resize(kDeviceTensorNum);
|
||||
output_device_tensor_.resize(kDeviceTensorNum);
|
||||
|
||||
|
@ -43,7 +48,7 @@ void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<Devi
|
|||
auto &sequential_num = context->sequential_num_;
|
||||
(void)input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
// When all the inputs are collected, then allocate memory and callback copy.
|
||||
if (CheckCopyCondition(context)) {
|
||||
if (CheckRunningCondition(context)) {
|
||||
FetchDeviceTensor(context);
|
||||
SendMemoryAllocReq(context);
|
||||
}
|
||||
|
@ -54,20 +59,20 @@ void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *
|
|||
auto &sequential_num = context->sequential_num_;
|
||||
(void)input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
// When all the inputs are collected, then allocate memory and callback copy.
|
||||
if (CheckCopyCondition(context)) {
|
||||
if (CheckRunningCondition(context)) {
|
||||
FetchDeviceTensor(context);
|
||||
SendMemoryAllocReq(context);
|
||||
}
|
||||
}
|
||||
|
||||
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_, output_device_context_,
|
||||
context, GetAID());
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_, device_contexts_[1], context,
|
||||
GetAID());
|
||||
}
|
||||
|
||||
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_, input_device_context_, context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_, output_device_context_, context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_, device_contexts_[0], context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_, device_contexts_[1], context);
|
||||
}
|
||||
|
||||
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
|
@ -96,50 +101,28 @@ void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter == input_op_datas_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_datas_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
|
||||
if (control_iter == input_op_controls_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (control_iter->second.size() != input_controls_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(input_device_context_);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
|
||||
if (device_tensor_store_key_.second != nullptr) {
|
||||
input_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key_.second,
|
||||
input_device_context_->GetDeviceAddressType());
|
||||
if (device_tensor_store_keys_.size() > 0) {
|
||||
input_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_keys_[0].second.get(),
|
||||
device_contexts_[0]->GetDeviceAddressType());
|
||||
if (input_device_tensor_[0] == nullptr) {
|
||||
std::string error_info =
|
||||
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key_.second->fullname_with_scope() +
|
||||
", device type:" + std::to_string(static_cast<int>(input_device_context_->GetDeviceAddressType()));
|
||||
GetAID().Name() +
|
||||
" get device tensor store failed: " + device_tensor_store_keys_[0].second->fullname_with_scope() +
|
||||
", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
||||
output_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key_.second,
|
||||
output_device_context_->GetDeviceAddressType());
|
||||
output_device_tensor_[0] = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_keys_[0].second.get(),
|
||||
device_contexts_[1]->GetDeviceAddressType());
|
||||
if (output_device_tensor_[0] == nullptr) {
|
||||
std::string error_info =
|
||||
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key_.second->fullname_with_scope() +
|
||||
", device type:" + std::to_string(static_cast<int>(output_device_context_->GetDeviceAddressType()));
|
||||
GetAID().Name() +
|
||||
" get device tensor store failed: " + device_tensor_store_keys_[0].second->fullname_with_scope() +
|
||||
", device type:" + std::to_string(static_cast<int>(device_contexts_[1]->GetDeviceAddressType()));
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
} else {
|
||||
|
@ -178,24 +161,5 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CopyActor::EraseInput(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input data failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
auto ret = input_op_controls_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input controls failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,18 +32,12 @@ namespace runtime {
|
|||
using mindspore::device::DeviceContext;
|
||||
|
||||
// The copy actor is used to receive the device tensors and control info to copy data between input device tensor and
|
||||
// output device tensor. The processing flow is RunOpData/RunOpControl -> CheckCopyCondition -> SendMemoryAllocReq
|
||||
// output device tensor. The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
|
||||
// -> OnMemoryAllocFinish -> Copy -> SendMemoryFreeReq -> SendOutput.
|
||||
class CopyActor : public MemoryAwareActor {
|
||||
public:
|
||||
CopyActor(const std::string &name, const AID &memory_manager_aid)
|
||||
: MemoryAwareActor(name),
|
||||
memory_manager_aid_(memory_manager_aid),
|
||||
input_datas_num_(0),
|
||||
input_controls_num_(0),
|
||||
input_device_context_(nullptr),
|
||||
output_device_context_(nullptr),
|
||||
output_(nullptr) {}
|
||||
: MemoryAwareActor(name, nullptr, memory_manager_aid), output_(nullptr) {}
|
||||
~CopyActor() override = default;
|
||||
|
||||
void Init() override;
|
||||
|
@ -62,34 +56,15 @@ class CopyActor : public MemoryAwareActor {
|
|||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Check whether satisfy the condition for copy.
|
||||
bool CheckCopyCondition(OpContext<DeviceTensor> *const context) const;
|
||||
// Fetch the device tensor for copy.
|
||||
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// Send output data and output controls when finish copy.
|
||||
void SendOutput(OpContext<DeviceTensor> *const context) const;
|
||||
// Erase input data and input controls when finish copy.
|
||||
void EraseInput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc and free memory during the copy.
|
||||
const AID memory_manager_aid_;
|
||||
|
||||
// The dependent input data number.
|
||||
size_t input_datas_num_;
|
||||
// The dependent input controls number.
|
||||
size_t input_controls_num_;
|
||||
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::pair<size_t, AnfNode *> device_tensor_store_key_;
|
||||
|
||||
// The device interface for copy.
|
||||
const DeviceContext *input_device_context_;
|
||||
const DeviceContext *output_device_context_;
|
||||
|
||||
// The input device tensor is saved from the input data or fetched by device_tensor_store_key_.
|
||||
// The input device tensor is saved from the input data or fetched by device_tensor_store_keys_.
|
||||
std::vector<DeviceTensor *> input_device_tensor_;
|
||||
// The output device tensor is saved from the output or fetched by device_tensor_store_key_.
|
||||
// The output device tensor is saved from the output or fetched by device_tensor_store_keys_.
|
||||
std::vector<DeviceTensor *> output_device_tensor_;
|
||||
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
|
|
|
@ -27,6 +27,11 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void DataSourceActor::Init() {
|
||||
// Check device contexts number.
|
||||
if (device_contexts_.size() < kDeviceContextsNumOne) {
|
||||
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
|
||||
}
|
||||
|
||||
// Init output data.
|
||||
for (auto &data_arrow : output_data_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
|
@ -98,6 +103,11 @@ void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
|||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::Init() {
|
||||
// Check device contexts number.
|
||||
if (device_contexts_.size() != kDeviceContextsNumOne) {
|
||||
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
|
||||
}
|
||||
|
||||
// Init output data.
|
||||
for (auto &data_arrow : output_data_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
|
@ -126,17 +136,18 @@ void DeviceQueueDataSourceActor::FillDataBuffer() {
|
|||
|
||||
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
auto &device_tensors = buffers_.back();
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_context_, context, GetAID());
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
|
||||
GetAID());
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
auto &device_tensors = buffers_.front();
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_context_, context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
if (buffers_.size() == 0) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
|
||||
}
|
||||
|
@ -151,8 +162,8 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
|
|||
|
||||
// Copy data from device queue by data kernel launching.
|
||||
try {
|
||||
auto ret = device_context_->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_);
|
||||
auto ret = device_contexts_[0]->LaunchKernel(data_kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_);
|
||||
if (!ret) {
|
||||
std::string error_info = "Launch kernel failed: " + data_kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
|
@ -178,7 +189,7 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
|
|||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_context_, context, &GetAID());
|
||||
Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_contexts_[0], context, &GetAID());
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
|
||||
|
@ -197,7 +208,7 @@ void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const conte
|
|||
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
|
||||
if (recorder_aid_ != nullptr) {
|
||||
Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
|
||||
device_context_, context);
|
||||
device_contexts_[0], context);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -41,13 +41,9 @@ using mindspore::kernel::KernelLaunchInfo;
|
|||
// -> OnMemoryAllocFinish -> SendMemoryFreeReq -> SendOutput.
|
||||
class DataSourceActor : public DebugAwareActor {
|
||||
public:
|
||||
DataSourceActor(const std::string &name, size_t buffer_capacity, const AID memory_manager_aid, const AID *debug_aid,
|
||||
DataSourceActor(const std::string &name, size_t buffer_capacity, const AID &memory_manager_aid, const AID *debug_aid,
|
||||
const AID *recorder_aid)
|
||||
: DebugAwareActor(name),
|
||||
buffer_capacity_(buffer_capacity),
|
||||
memory_manager_aid_(memory_manager_aid),
|
||||
debug_aid_(debug_aid),
|
||||
recorder_aid_(recorder_aid) {}
|
||||
: DebugAwareActor(name, recorder_aid, memory_manager_aid, debug_aid), buffer_capacity_(buffer_capacity) {}
|
||||
virtual ~DataSourceActor() = default;
|
||||
|
||||
void Init() override;
|
||||
|
@ -76,20 +72,10 @@ class DataSourceActor : public DebugAwareActor {
|
|||
// Send output to downstream actors to trigger computing after fetching data finished.
|
||||
void SendOutput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The output result arrows of graph output.
|
||||
std::vector<DataArrowPtr> output_result_arrows_;
|
||||
|
||||
// The buffers store the device tensors.
|
||||
std::queue<std::vector<DeviceTensor *>> buffers_;
|
||||
size_t buffer_capacity_;
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc and free memory during the data processing.
|
||||
const AID memory_manager_aid_;
|
||||
// The id of debug actor. Send message to it for debug after the kernel launch.
|
||||
const AID *debug_aid_;
|
||||
// The id of recorder actor. Send message to it for recording kernel info after the kernel launch.
|
||||
const AID *recorder_aid_;
|
||||
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
std::vector<OpDataUniquePtr<DeviceTensor>> output_data_;
|
||||
};
|
||||
|
@ -97,10 +83,11 @@ class DataSourceActor : public DebugAwareActor {
|
|||
// The class represents that the data source is device queue.
|
||||
class DeviceQueueDataSourceActor : public DataSourceActor {
|
||||
public:
|
||||
DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
|
||||
const AID memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
|
||||
: DataSourceActor(name, buffer_capacity, memory_manager_aid, debug_aid, recorder_aid),
|
||||
device_context_(device_context) {}
|
||||
DeviceQueueDataSourceActor(const std::string &name, size_t buffer_capacity, const DeviceContext *device_context,
|
||||
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid)
|
||||
: DataSourceActor(name, buffer_capacity, memory_manager_aid, debug_aid, recorder_aid) {
|
||||
(void)device_contexts_.emplace_back(device_context);
|
||||
}
|
||||
~DeviceQueueDataSourceActor() override = default;
|
||||
|
||||
void Init() override;
|
||||
|
@ -126,8 +113,6 @@ class DeviceQueueDataSourceActor : public DataSourceActor {
|
|||
|
||||
// The kernel launch info is fetched by the device tensors.
|
||||
KernelLaunchInfo launch_info_;
|
||||
|
||||
const DeviceContext *device_context_;
|
||||
};
|
||||
|
||||
// The class represents that the data source is host queue.
|
||||
|
@ -157,8 +142,6 @@ class HostQueueDataSourceActor : public DataSourceActor {
|
|||
HostTensorQueuePtr host_queue_;
|
||||
// Input data nodes fetch data from host queue.
|
||||
std::vector<AnfNodePtr> data_nodes_;
|
||||
// The device contexts corresponding to the data nodes.
|
||||
std::vector<const DeviceContext *> device_contexts_;
|
||||
|
||||
// The location of the data node in the data source actor.
|
||||
std::unordered_map<AnfNodePtr, size_t> data_node_position_map_;
|
||||
|
|
|
@ -25,10 +25,17 @@ namespace runtime {
|
|||
// The actor represents a set of common debug related operations of actor.
|
||||
class DebugAwareActor : public MemoryAwareActor {
|
||||
public:
|
||||
explicit DebugAwareActor(const std::string &name) : MemoryAwareActor(name) {}
|
||||
explicit DebugAwareActor(const std::string &name, const AID *recorder_aid, const AID &memory_manager_aid,
|
||||
const AID *debug_aid)
|
||||
: MemoryAwareActor(name, recorder_aid, memory_manager_aid), debug_aid_(debug_aid) {}
|
||||
virtual ~DebugAwareActor() = default;
|
||||
|
||||
virtual void SendDebugReq(OpContext<DeviceTensor> *const context) {}
|
||||
virtual void OnDebugFinish(OpContext<DeviceTensor> *const context) {}
|
||||
|
||||
protected:
|
||||
// The id of debug actor. Send message to it for debug.
|
||||
const AID *debug_aid_;
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,6 +25,11 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void KernelActor::Init() {
|
||||
// Check device contexts number.
|
||||
if (device_contexts_.size() != kDeviceContextsNumOne) {
|
||||
MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
|
||||
}
|
||||
|
||||
// Set the number of actor running dependent messages.
|
||||
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
|
||||
|
||||
|
@ -84,10 +89,10 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<De
|
|||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
// When all the inputs are collected, then allocate memory and callback launch.
|
||||
if (CheckLaunchCondition(context)) {
|
||||
if (CheckRunningCondition(context)) {
|
||||
// Infer kernel shape and update abstract info for dynamic shape kernel.
|
||||
if (is_dynamic_shape_) {
|
||||
device_context_->UpdateDynamicShape(kernel_);
|
||||
device_contexts_[0]->UpdateDynamicShape(kernel_);
|
||||
}
|
||||
|
||||
FetchInputDeviceTensor(context);
|
||||
|
@ -105,10 +110,10 @@ void KernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor>
|
|||
auto &sequential_num = context->sequential_num_;
|
||||
(void)input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
// When all the inputs are collected, then allocate memory and callback launch.
|
||||
if (CheckLaunchCondition(context)) {
|
||||
if (CheckRunningCondition(context)) {
|
||||
// Infer kernel shape and update abstract info for dynamic shape kernel.
|
||||
if (is_dynamic_shape_) {
|
||||
device_context_->UpdateDynamicShape(kernel_);
|
||||
device_contexts_[0]->UpdateDynamicShape(kernel_);
|
||||
}
|
||||
|
||||
FetchInputDeviceTensor(context);
|
||||
|
@ -130,7 +135,7 @@ void KernelActor::RunOpControlWithInputTensor(AID *const input_control, OpContex
|
|||
|
||||
PushInputDeviceTensor(input_tensors);
|
||||
// When all the inputs are collected, then allocate memory and callback launch.
|
||||
if (CheckLaunchCondition(context)) {
|
||||
if (CheckRunningCondition(context)) {
|
||||
FetchOutputDeviceTensor();
|
||||
if (memory_alloc_list_.size() > 0) {
|
||||
SendMemoryAllocReq(context);
|
||||
|
@ -181,30 +186,30 @@ void FreeMemory(const std::vector<DeviceTensor *> &free_list, const DeviceContex
|
|||
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
running_dependent_msg_num_ = 1;
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_, device_context_, context,
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_, device_contexts_[0], context,
|
||||
GetAID());
|
||||
} else {
|
||||
AllocateMemory(memory_alloc_list_, device_context_);
|
||||
AllocateMemory(memory_alloc_list_, device_contexts_[0]);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_context_, context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_contexts_[0], context);
|
||||
} else {
|
||||
FreeMemory(memory_free_list_, device_context_);
|
||||
FreeMemory(memory_free_list_, device_contexts_[0]);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(kernel_);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
PreLaunchKernel(context);
|
||||
|
||||
try {
|
||||
auto ret = device_context_->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_, is_dynamic_shape_);
|
||||
auto ret = device_contexts_[0]->LaunchKernel(kernel_, launch_info_.inputs_, launch_info_.workspaces_,
|
||||
launch_info_.outputs_, is_dynamic_shape_);
|
||||
if (!ret) {
|
||||
std::string error_info = "Launch kernel failed: " + kernel_->fullname_with_scope();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
|
@ -226,7 +231,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
|||
|
||||
void KernelActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
|
||||
running_dependent_msg_num_ = 1;
|
||||
Async(*debug_aid_, &DebugActor::Debug, kernel_, &launch_info_, device_context_, context, &GetAID());
|
||||
Async(*debug_aid_, &DebugActor::Debug, kernel_, &launch_info_, device_contexts_[0], context, &GetAID());
|
||||
}
|
||||
|
||||
void KernelActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
|
||||
|
@ -234,30 +239,6 @@ void KernelActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
|
|||
PostLaunchKernel(context);
|
||||
}
|
||||
|
||||
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter == input_op_datas_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_datas_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
|
||||
if (control_iter == input_op_controls_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (control_iter->second.size() != input_controls_num_) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(input_tensors);
|
||||
if (input_tensors->size() != real_input_num_) {
|
||||
|
@ -279,24 +260,25 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
|
|||
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
||||
OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(input_data);
|
||||
if ((input_data->data_ == nullptr) || (input_data->data_->DeviceType() == device_context_->GetDeviceAddressType())) {
|
||||
if ((input_data->data_ == nullptr) ||
|
||||
(input_data->data_->DeviceType() == device_contexts_[0]->GetDeviceAddressType())) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Copy from device type: " << input_data->data_->DeviceType()
|
||||
<< " to device type: " << device_context_->GetDeviceAddressType() << " in " << GetAID().Name();
|
||||
<< " to device type: " << device_contexts_[0]->GetDeviceAddressType() << " in " << GetAID().Name();
|
||||
if (copy_input_device_tensors_[input_data->index_] == nullptr) {
|
||||
copy_input_device_tensors_[input_data->index_] = device_context_->CreateDeviceAddress(
|
||||
copy_input_device_tensors_[input_data->index_] = device_contexts_[0]->CreateDeviceAddress(
|
||||
nullptr, input_data->data_->GetSize(), input_data->data_->format(), input_data->data_->type_id());
|
||||
}
|
||||
// Dynamic shape need update size.
|
||||
copy_input_device_tensors_[input_data->index_]->SetSize(input_data->data_->GetSize());
|
||||
|
||||
if (copy_input_device_tensors_[input_data->index_]->GetPtr() == nullptr) {
|
||||
if (!device_context_->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
|
||||
copy_input_device_tensors_[input_data->index_]->GetSize())) {
|
||||
if (!device_contexts_[0]->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(),
|
||||
copy_input_device_tensors_[input_data->index_]->GetSize())) {
|
||||
std::string error_info =
|
||||
"Device(id:" + std::to_string(device_context_->device_context_key().device_id_) +
|
||||
"Device(id:" + std::to_string(device_contexts_[0]->device_context_key().device_id_) +
|
||||
") memory isn't enough and alloc failed, actor name: " + GetAID().Name() +
|
||||
", alloc size: " + std::to_string(copy_input_device_tensors_[input_data->index_]->GetSize());
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
|
@ -315,7 +297,7 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
|||
|
||||
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts_[0]);
|
||||
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter != input_op_datas_.end()) {
|
||||
|
@ -330,12 +312,12 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context)
|
|||
}
|
||||
|
||||
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
|
||||
auto device_tensor =
|
||||
DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second, device_context_->GetDeviceAddressType());
|
||||
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get(),
|
||||
device_contexts_[0]->GetDeviceAddressType());
|
||||
if (device_tensor == nullptr) {
|
||||
std::string error_info =
|
||||
GetAID().Name() + " get device tensor store failed: " + device_tensor_store_key.second->fullname_with_scope() +
|
||||
", device type:" + std::to_string(static_cast<int>(device_context_->GetDeviceAddressType()));
|
||||
", device type:" + std::to_string(static_cast<int>(device_contexts_[0]->GetDeviceAddressType()));
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
|
||||
}
|
||||
if (input_device_tensors_[device_tensor_store_key.first] != device_tensor) {
|
||||
|
@ -439,8 +421,8 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
|
|||
|
||||
// 4.Send recorder info.
|
||||
if (recorder_aid_ != nullptr) {
|
||||
Async(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &launch_info_, device_context_,
|
||||
context);
|
||||
Async(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &launch_info_,
|
||||
device_contexts_[0], context);
|
||||
}
|
||||
|
||||
// No output.
|
||||
|
@ -449,28 +431,5 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
|
|||
SET_OPCONTEXT_SUCCESS_RET((*context));
|
||||
}
|
||||
}
|
||||
|
||||
void KernelActor::EraseInput(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input data failed: " + GetAID().Name();
|
||||
// The sequential num may be invalid, can't set the promise value of context.
|
||||
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (input_controls_num_ != 0) {
|
||||
auto ret = input_op_controls_.erase(context->sequential_num_);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input controls failed: " + GetAID().Name();
|
||||
// The sequential num may be invalid, can't set the promise value of context.
|
||||
MS_LOG(ERROR) << error_info << ", sequential_num: " << context->sequential_num_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -39,30 +39,24 @@ using mindspore::kernel::KernelLaunchInfo;
|
|||
using mindspore::tensor::TensorPtr;
|
||||
|
||||
// The kernel actor is used to receive the device tensors and control info to luanch kernel.
|
||||
// The processing flow is RunOpData/RunOpControl -> CheckLaunchCondition -> SendMemoryAllocReq
|
||||
// The processing flow is RunOpData/RunOpControl -> CheckRunningCondition -> SendMemoryAllocReq
|
||||
// -> OnMemoryAllocFinish -> LaunchKernel -> SendMemoryFreeReq -> SendOutput.
|
||||
class KernelActor : public DebugAwareActor {
|
||||
public:
|
||||
KernelActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
|
||||
const AID memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
|
||||
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
|
||||
GraphExecutionStrategy strategy)
|
||||
: DebugAwareActor(name),
|
||||
: DebugAwareActor(name, recorder_aid, memory_manager_aid, debug_aid),
|
||||
kernel_(kernel),
|
||||
kernel_info_(nullptr),
|
||||
is_dynamic_shape_(false),
|
||||
device_context_(device_context),
|
||||
memory_manager_aid_(memory_manager_aid),
|
||||
debug_aid_(debug_aid),
|
||||
recorder_aid_(recorder_aid),
|
||||
input_datas_num_(0),
|
||||
input_controls_num_(0),
|
||||
real_input_num_(0),
|
||||
running_dependent_msg_num_(1),
|
||||
strategy_(strategy) {}
|
||||
strategy_(strategy) {
|
||||
(void)device_contexts_.emplace_back(device_context);
|
||||
}
|
||||
~KernelActor() override = default;
|
||||
|
||||
void Init() override;
|
||||
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
|
||||
|
||||
// The kernel actor run when receive the input data.
|
||||
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
|
||||
|
@ -86,8 +80,6 @@ class KernelActor : public DebugAwareActor {
|
|||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Check whether satisfy the condition for launch.
|
||||
bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
|
||||
// Fetch the device tensor for launch.
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
void FetchOutputDeviceTensor();
|
||||
|
@ -102,45 +94,20 @@ class KernelActor : public DebugAwareActor {
|
|||
|
||||
// Send output data and output controls when finish kernel launch.
|
||||
void SendOutput(OpContext<DeviceTensor> *const context) const;
|
||||
// Erase input data and input controls when finish kernel launch.
|
||||
void EraseInput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The info of kernel.
|
||||
CNodePtr kernel_;
|
||||
KernelInfo *kernel_info_;
|
||||
bool is_dynamic_shape_;
|
||||
|
||||
// The device interface of kernel launch.
|
||||
const DeviceContext *device_context_;
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc and free memory during the kernel launch.
|
||||
const AID memory_manager_aid_;
|
||||
// The id of debug actor. Send message to it for debug after the kernel launch.
|
||||
const AID *debug_aid_;
|
||||
// The id of recorder actor. Send message to it for recording kernel info after the kernel launch.
|
||||
const AID *recorder_aid_;
|
||||
|
||||
// The dependent input data number.
|
||||
size_t input_datas_num_;
|
||||
// The dependent input controls number.
|
||||
size_t input_controls_num_;
|
||||
// The real input number of kernel launch.
|
||||
size_t real_input_num_;
|
||||
// The dependent messages number of actor running.
|
||||
int running_dependent_msg_num_;
|
||||
|
||||
// The execution strategy of kernel actor.
|
||||
// In pipeline mode, kernel actor executes asynchronously.
|
||||
// In step mode, kernel actor executes synchronously.
|
||||
GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline};
|
||||
|
||||
// The dependent input actors.
|
||||
std::vector<AID> input_data_arrow_aids_;
|
||||
std::vector<AID> input_control_arrow_aids_;
|
||||
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, AnfNode *>> device_tensor_store_keys_;
|
||||
|
||||
// The device tensors for launch.
|
||||
std::vector<DeviceTensor *> input_device_tensors_;
|
||||
std::vector<DeviceTensor *> output_device_tensors_;
|
||||
|
@ -160,9 +127,6 @@ class KernelActor : public DebugAwareActor {
|
|||
// The kernel launch info is fetched by the device tensors.
|
||||
KernelLaunchInfo launch_info_;
|
||||
|
||||
// The output result arrows of graph output.
|
||||
std::vector<DataArrowPtr> output_result_arrows_;
|
||||
|
||||
// Cache unique output data by output index to modify the output data effectively.
|
||||
std::vector<std::vector<OpDataUniquePtr<DeviceTensor>>> output_data_by_output_index_;
|
||||
// The output_data_ corresponds to the output_data_arrows_ one by one.
|
||||
|
|
|
@ -86,7 +86,7 @@ void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTens
|
|||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
(void)input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
if (CheckLoopCountIncreaseCondition(context)) {
|
||||
if (CheckRunningCondition(context)) {
|
||||
IncreaseLoopCount(context);
|
||||
}
|
||||
}
|
||||
|
@ -102,12 +102,7 @@ void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
|
|||
|
||||
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
auto ret = input_op_controls_.erase(sequential_num);
|
||||
if (ret == 0) {
|
||||
std::string error_info = "Erase input controls failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
EraseInput(context);
|
||||
|
||||
total_running_count_++;
|
||||
current_count_++;
|
||||
|
@ -165,12 +160,5 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context)
|
|||
Async(kernel_aid, &KernelActor::RunOpControl, source_aid, context);
|
||||
}
|
||||
}
|
||||
|
||||
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
|
||||
return input_op_controls_[sequential_num].size() == input_controls_num_;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,16 +34,12 @@ namespace runtime {
|
|||
// and decide whether to loop execution by loop count.
|
||||
class LoopCountActor : public DebugAwareActor {
|
||||
public:
|
||||
LoopCountActor(std::string name, size_t loop_count, const AID memory_manager_aid, const AID *debug_aid,
|
||||
LoopCountActor(const std::string &name, size_t loop_count, const AID &memory_manager_aid, const AID *debug_aid,
|
||||
const AID *recorder_aid)
|
||||
: DebugAwareActor(name),
|
||||
: DebugAwareActor(name, recorder_aid, memory_manager_aid, debug_aid),
|
||||
loop_count_(loop_count),
|
||||
current_count_(0),
|
||||
total_running_count_(0),
|
||||
input_controls_num_(0),
|
||||
memory_manager_aid_(memory_manager_aid),
|
||||
debug_aid_(debug_aid),
|
||||
recorder_aid_(recorder_aid) {}
|
||||
total_running_count_(0) {}
|
||||
|
||||
~LoopCountActor() override = default;
|
||||
|
||||
|
@ -68,30 +64,17 @@ class LoopCountActor : public DebugAwareActor {
|
|||
void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
|
||||
void SendOutput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context);
|
||||
// The loop count is constant, the current count is increased after each step running finished.
|
||||
size_t loop_count_;
|
||||
size_t current_count_;
|
||||
// The total running count represents the toal step running count.
|
||||
size_t total_running_count_;
|
||||
|
||||
// The dependent input controls number.
|
||||
// In the multi-branch output scenario of the control flow, the control of each branch needs to be recorded
|
||||
// separately with the branch id as the key. When the output has only one branch, the branch id is 0.
|
||||
size_t input_controls_num_;
|
||||
|
||||
// The output controls contain the data source actors and the no input kernel actors and output actor.
|
||||
std::vector<AID> data_source_aids_;
|
||||
std::vector<AID> no_input_kernel_aids_;
|
||||
AID output_aid_;
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc continuous memory before next step running.
|
||||
const AID memory_manager_aid_;
|
||||
// The id of debug actor. Send message to it for debug before loop count actor exits.
|
||||
const AID *debug_aid_;
|
||||
// The id of recorder actor. Send message to it for clearing recorder info before loop count actor exits.
|
||||
const AID *recorder_aid_;
|
||||
|
||||
// The nodes need continuous memory, which must allocate in the begin of step running. The first bool of pair
|
||||
// expresses the inputs of node need continuous memory, the second bool of pair expresses the outputs of node need
|
||||
// continuous memory.
|
||||
|
@ -100,7 +83,6 @@ class LoopCountActor : public DebugAwareActor {
|
|||
std::vector<std::vector<DeviceTensorPtr>> continuous_memory_alloc_list_list_;
|
||||
std::vector<std::vector<size_t>> size_list_list_;
|
||||
std::vector<size_t> total_size_list_;
|
||||
std::vector<const DeviceContext *> device_contexts_;
|
||||
};
|
||||
|
||||
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
||||
|
|
|
@ -19,21 +19,27 @@
|
|||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/framework/actor/abstract_actor.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// The actor represents a set of common memory related operations of actor.
|
||||
class MemoryAwareActor : public OpActor<DeviceTensor> {
|
||||
class MemoryAwareActor : public AbstractActor {
|
||||
public:
|
||||
explicit MemoryAwareActor(std::string name) : OpActor(name) {}
|
||||
explicit MemoryAwareActor(const std::string &name, const AID *recorder_aid, const AID &memory_manager_aid)
|
||||
: AbstractActor(name, recorder_aid), memory_manager_aid_(memory_manager_aid) {}
|
||||
virtual ~MemoryAwareActor() = default;
|
||||
|
||||
virtual void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {}
|
||||
virtual void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {}
|
||||
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {}
|
||||
|
||||
protected:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc and free memory.
|
||||
const AID memory_manager_aid_;
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "runtime/framework/control_node_parser.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "runtime/framework/actor/actor_common.h"
|
||||
#include "runtime/framework/actor/abstract_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/tensor.h"
|
||||
|
@ -37,16 +38,15 @@ using mindspore::session::KernelWithIndex;
|
|||
using mindspore::tensor::TensorPtr;
|
||||
|
||||
// The output actor is used to receive the output result of actor which represents the graph output.
|
||||
class OutputActor : public OpActor<DeviceTensor> {
|
||||
class OutputActor : public AbstractActor {
|
||||
public:
|
||||
OutputActor(std::string name, size_t loop_count, size_t outputs_num, bool need_loop_count)
|
||||
: OpActor(name),
|
||||
: AbstractActor(name, nullptr),
|
||||
loop_count_(loop_count),
|
||||
current_count_(0),
|
||||
outputs_num_(outputs_num),
|
||||
current_outputs_num_(0),
|
||||
need_loop_count_(need_loop_count),
|
||||
running_dependent_msg_num_(1) {
|
||||
need_loop_count_(need_loop_count) {
|
||||
outputs_.resize(outputs_num);
|
||||
output_nodes_.resize(outputs_num);
|
||||
device_contexts_.resize(outputs_num);
|
||||
|
@ -54,7 +54,6 @@ class OutputActor : public OpActor<DeviceTensor> {
|
|||
~OutputActor() override = default;
|
||||
|
||||
void Init() override;
|
||||
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
|
||||
|
||||
// The output actor collects loop count when receive the input control of loop count actor.
|
||||
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context);
|
||||
|
@ -80,15 +79,9 @@ class OutputActor : public OpActor<DeviceTensor> {
|
|||
// The outputs.
|
||||
std::vector<TensorPtr> outputs_;
|
||||
std::vector<KernelWithIndex> output_nodes_;
|
||||
std::vector<const DeviceContext *> device_contexts_;
|
||||
size_t outputs_num_;
|
||||
size_t current_outputs_num_;
|
||||
bool need_loop_count_;
|
||||
|
||||
// The dependent messages number of actor running.
|
||||
int running_dependent_msg_num_;
|
||||
|
||||
std::vector<std::pair<size_t, AnfNodePtr>> device_tensor_store_keys_;
|
||||
};
|
||||
|
||||
using OutputActorPtr = std::shared_ptr<OutputActor>;
|
||||
|
|
|
@ -1361,7 +1361,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
|
|||
AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, SizeToInt(kernel_with_index.second));
|
||||
if (HasAbstractRef(real_front_node_with_index.first)) {
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
|
||||
real_front_node_with_index.first.get());
|
||||
real_front_node_with_index.first);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1388,7 +1388,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
|
|||
auto actor_name = func_graph->ToString();
|
||||
const auto &from_actor = dynamic_cast<GatherActor *>(FetchActor(actor_name));
|
||||
if (HasAbstractRef(from_kernel)) {
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node.get());
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node);
|
||||
return;
|
||||
}
|
||||
LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
|
||||
|
@ -1409,8 +1409,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi
|
|||
to_kernel_with_input_idx);
|
||||
} else if (IsPersistentDeviceTensor(from_kernel)) {
|
||||
const auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
|
||||
device_tensor_store_key.get());
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
|
||||
} else {
|
||||
// May exist the from kernel that no need link in the pynative mode.
|
||||
MS_LOG(DEBUG) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
|
||||
|
@ -1437,7 +1436,7 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
return;
|
||||
}
|
||||
if (IsPersistentDeviceTensor(front_output_node)) {
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get());
|
||||
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1486,7 +1485,7 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *c
|
|||
auto from_output_index = from_kernel_with_output_idx.second;
|
||||
auto to_input_index = to_kernel_with_input_idx.second;
|
||||
|
||||
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
|
||||
if (IsNeedInsertCopyActor(from_actor->device_contexts_[0], to_actor->device_contexts_[0])) {
|
||||
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
auto to_aid = to_actor->GetAID();
|
||||
|
@ -1514,7 +1513,7 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *const
|
|||
|
||||
// Get the position of from kernel in the data source actor.
|
||||
auto position = from_actor->FetchDataNodePosition(from_kernel);
|
||||
if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_context_)) {
|
||||
if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
|
||||
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
auto to_aid = to_actor->GetAID();
|
||||
|
@ -1554,7 +1553,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
|
|||
auto from_output_index = from_kernel_with_output_idx.second;
|
||||
auto to_input_index = to_kernel_with_input_idx.second;
|
||||
|
||||
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
|
||||
if (IsNeedInsertCopyActor(from_actor->device_contexts_[0], to_actor->device_contexts_[0])) {
|
||||
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
auto to_aid = to_actor->GetAID();
|
||||
|
@ -1575,7 +1574,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from
|
|||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
auto from_kernel = from_kernel_with_output_idx.first;
|
||||
MS_EXCEPTION_IF_NULL(from_kernel);
|
||||
auto to_device_context = to_actor->device_context_;
|
||||
auto to_device_context = to_actor->device_contexts_[0];
|
||||
MS_EXCEPTION_IF_NULL(to_device_context);
|
||||
auto from_output_index = from_kernel_with_output_idx.second;
|
||||
auto to_input_index = to_kernel_with_input_idx.second;
|
||||
|
@ -1599,12 +1598,12 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from
|
|||
if (IsDeviceQueueDSActor(from_kernel)) {
|
||||
auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_from_actor);
|
||||
from_device_context = real_from_actor->device_context_;
|
||||
from_device_context = real_from_actor->device_contexts_[0];
|
||||
(void)real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
|
||||
} else if (IsKernelActor(from_kernel)) {
|
||||
auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_from_actor);
|
||||
from_device_context = real_from_actor->device_context_;
|
||||
from_device_context = real_from_actor->device_contexts_[0];
|
||||
(void)real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
|
||||
} else if (IsHostQueueDSActor(from_kernel)) {
|
||||
auto real_from_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
|
||||
|
@ -1629,8 +1628,8 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from
|
|||
copy_actor->output_ = to_device_context->CreateDeviceAddress(
|
||||
nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
|
||||
MS_EXCEPTION_IF_NULL(from_device_context);
|
||||
copy_actor->input_device_context_ = from_device_context;
|
||||
copy_actor->output_device_context_ = to_device_context;
|
||||
(void)copy_actor->device_contexts_.emplace_back(from_device_context);
|
||||
(void)copy_actor->device_contexts_.emplace_back(to_device_context);
|
||||
|
||||
// Update the reference count of device tensor.
|
||||
UpdateRefCount(from_device_tensor.get());
|
||||
|
@ -2060,7 +2059,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
for (auto &kernel_actor : auto_monad_actors) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
|
||||
auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
|
||||
auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
|
||||
if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
|
||||
continue;
|
||||
}
|
||||
|
@ -2077,9 +2076,9 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
InsertActor(copy_actor.get());
|
||||
|
||||
// Set the member of the copy actor.
|
||||
copy_actor->device_tensor_store_key_ = std::pair<size_t, AnfNode *>(0, device_tensor_store_key.second);
|
||||
auto input_device_context = kernel_actor->device_context_;
|
||||
copy_actor->input_device_context_ = input_device_context;
|
||||
(void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
|
||||
auto input_device_context = kernel_actor->device_contexts_[0];
|
||||
(void)copy_actor->device_contexts_.emplace_back(input_device_context);
|
||||
auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
|
||||
? device_tensors[1]
|
||||
: device_tensors[0];
|
||||
|
@ -2088,7 +2087,7 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
|
||||
{device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
|
||||
MS_EXCEPTION_IF_NULL(another_device_context);
|
||||
copy_actor->output_device_context_ = another_device_context;
|
||||
(void)copy_actor->device_contexts_.emplace_back(another_device_context);
|
||||
|
||||
MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
|
||||
<< "has control arrows number:" << kernel_actor->output_control_arrows_.size();
|
||||
|
@ -2662,7 +2661,7 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt
|
|||
|
||||
const size_t kCopyActorInputDataNum = 1;
|
||||
auto input_data_num = copy_actor->input_datas_num_;
|
||||
size_t device_tensor_store_num = (copy_actor->device_tensor_store_key_.second == nullptr) ? 0 : 1;
|
||||
size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
|
||||
if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
|
||||
MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
|
||||
<< " is wrong, input data num: " << input_data_num
|
||||
|
@ -2838,78 +2837,131 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf
|
|||
ofs << "[Device tensor stores]\n";
|
||||
DumpDeviceTensorStore(graph_compiler_info, ofs);
|
||||
|
||||
ofs << "\n\n[Data source actors]\n";
|
||||
ofs << "\n\n[Data source actors:" << actor_set->data_source_actors_.size() << "]\n";
|
||||
for (const auto &data_source_actor : actor_set->data_source_actors_) {
|
||||
DumpDSActor(data_source_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[Kernel actors]\n";
|
||||
ofs << "\n\n[Kernel actors:" << actor_set->kernel_actors_.size() << "]\n";
|
||||
for (const auto &kernel_actor : actor_set->kernel_actors_) {
|
||||
DumpKernelActor(kernel_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[No input kernel actors]\n";
|
||||
ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n";
|
||||
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
|
||||
DumpKernelActor(no_input_kernel_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[Copy actors]\n";
|
||||
ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n";
|
||||
for (const auto ©_actor : actor_set->copy_actors_) {
|
||||
DumpCopyActor(copy_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[Gather actors]\n";
|
||||
ofs << "\n\n[Gather actors:" << actor_set->gather_actors_.size() << "]\n";
|
||||
for (const auto &gather_actor : actor_set->gather_actors_) {
|
||||
DumpGatherActor(gather_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[Switch actors]\n";
|
||||
ofs << "\n\n[Switch actors:" << actor_set->switch_actors_.size() << "]\n";
|
||||
for (const auto &switch_actor : actor_set->switch_actors_) {
|
||||
DumpSwitchActor(switch_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[Loop count actor]\n";
|
||||
const auto &loop_count_actor = actor_set->loop_count_actor_;
|
||||
ofs << "\n\n[Loop count actor:" << (loop_count_actor != nullptr ? 1 : 0) << "]\n";
|
||||
if (loop_count_actor != nullptr) {
|
||||
DumpLoopCountActor(loop_count_actor.get(), ofs);
|
||||
}
|
||||
|
||||
ofs << "\n\n[Output actor]\n";
|
||||
const auto &output_actor = actor_set->output_actor_;
|
||||
ofs << "\n\n[Output actor:" << (output_actor != nullptr ? 1 : 0) << "]\n";
|
||||
if (output_actor != nullptr) {
|
||||
DumpOutputActor(output_actor.get(), ofs);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpBaseActor(const OpActor<DeviceTensor> *actor, std::ofstream &ofs) const {
|
||||
void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size()
|
||||
<< "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size()
|
||||
<< "\tinput_data_arrow_actors_num:" << actor->input_datas_num_
|
||||
<< "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n";
|
||||
ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size()
|
||||
<< "\toutput_control_arrows_num:" << actor->output_control_arrows_.size()
|
||||
<< "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n";
|
||||
|
||||
if (actor->device_contexts_.size() > 0) {
|
||||
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
|
||||
for (const auto &device_context : actor->device_contexts_) {
|
||||
if (device_context == nullptr) {
|
||||
ofs << "\t\t\tdevice_context:" << device_context << "\n";
|
||||
continue;
|
||||
}
|
||||
ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (actor->device_tensor_store_keys_.size() > 0) {
|
||||
ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
|
||||
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
|
||||
ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
|
||||
<< "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (actor->input_data_arrow_aids_.size() > 0) {
|
||||
ofs << "\t\tinput_data_arrow_actors:" << actor->input_data_arrow_aids_.size() << "\n ";
|
||||
for (const auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
|
||||
ofs << "\t\t\tfrom_actor_name:" << input_data_arrow_aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (actor->input_control_arrow_aids_.size() > 0) {
|
||||
ofs << "\t\tinput_control_arrow_actors:" << actor->input_control_arrow_aids_.size() << "\n ";
|
||||
for (const auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
|
||||
ofs << "\t\t\tfrom_actor_name:" << input_control_arrow_aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_data_arrows = actor->output_data_arrows();
|
||||
ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
|
||||
for (const auto &data_arrow : output_data_arrows) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
|
||||
<< "\n";
|
||||
if (output_data_arrows.size() > 0) {
|
||||
ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
|
||||
for (const auto &data_arrow : output_data_arrows) {
|
||||
MS_EXCEPTION_IF_NULL(data_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
|
||||
const auto &output_control_arrows = actor->output_control_arrows();
|
||||
ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
|
||||
for (const auto &aid : output_control_arrows) {
|
||||
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
|
||||
if (output_control_arrows.size() > 0) {
|
||||
ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
|
||||
for (const auto &aid : output_control_arrows) {
|
||||
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (actor->output_result_arrows_.size() > 0) {
|
||||
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
|
||||
for (const auto &result_arrow : actor->output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
|
||||
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
const auto &actor_name = actor->GetAID().Name();
|
||||
ofs << "\tactor_name:" << actor_name << "\n";
|
||||
|
||||
if (actor_name.find("_DeviceDSActor") != string::npos) {
|
||||
// Dump the member info of device queue data source actor.
|
||||
const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(device_queue_ds_actor->device_context_);
|
||||
ofs << "\tactor_name:" << actor_name
|
||||
<< "\tdevice_context:" << device_queue_ds_actor->device_context_->device_context_key().ToString() << "\n";
|
||||
const auto &data_kernel = device_queue_ds_actor->data_kernel_;
|
||||
MS_EXCEPTION_IF_NULL(data_kernel);
|
||||
ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
|
||||
|
@ -2923,7 +2975,6 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
|
|||
}
|
||||
} else if (actor_name.find("_HostDSActor") != string::npos) {
|
||||
// Dump the member info of host queue data source actor.
|
||||
ofs << "\tactor_name:" << actor_name << "\n";
|
||||
const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
|
||||
ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
|
||||
for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
|
||||
|
@ -2933,27 +2984,18 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
|
|||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
|
||||
<< "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
|
||||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count()
|
||||
<< "\tdevice_context:" << host_queue_ds_actor->device_contexts_[i]->device_context_key().ToString() << "\n";
|
||||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
DumpBaseActor(actor, ofs);
|
||||
|
||||
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
|
||||
for (const auto &result_arrow : actor->output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
|
||||
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
|
||||
}
|
||||
DumpAbstractActor(actor, ofs);
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
|
||||
<< "\tinput_controls_num:" << actor->input_controls_num_ << "\n";
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ << "\n";
|
||||
DumpAbstractActor(actor, ofs);
|
||||
|
||||
ofs << "\t\toutput_control_arrows:" << (actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() + 1)
|
||||
<< "\n ";
|
||||
|
@ -2975,16 +3017,12 @@ void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstre
|
|||
|
||||
void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
MS_EXCEPTION_IF_NULL(actor->device_context_);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name()
|
||||
<< "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
|
||||
<< "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_
|
||||
<< "\n";
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
|
||||
|
||||
const auto &kernel = actor->kernel_;
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinput_number:" << AnfAlgo::GetInputTensorNum(kernel)
|
||||
<< "\toutput_number:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
|
||||
ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinputs_num:" << AnfAlgo::GetInputTensorNum(kernel)
|
||||
<< "\toutputs_num:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
|
||||
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
|
||||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
|
@ -2992,22 +3030,7 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of
|
|||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
|
||||
}
|
||||
|
||||
ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n ";
|
||||
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
|
||||
ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
|
||||
<< "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
|
||||
}
|
||||
|
||||
DumpBaseActor(actor, ofs);
|
||||
|
||||
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
|
||||
for (const auto &result_arrow : actor->output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
|
||||
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name()
|
||||
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
|
||||
}
|
||||
DumpAbstractActor(actor, ofs);
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
|
@ -3015,33 +3038,12 @@ void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &of
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
|
||||
<< "\toutputs_num:" << actor->outputs_num_ << "\n";
|
||||
|
||||
ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
|
||||
for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
|
||||
ofs << "\t\t\toutput_node_position:" << device_tensor_store_key.first
|
||||
<< "\toutput_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
|
||||
}
|
||||
|
||||
ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
|
||||
for (const auto &device_context : actor->device_contexts_) {
|
||||
if (device_context == nullptr) {
|
||||
ofs << "\t\t\tdevice_context:" << device_context << "\n";
|
||||
continue;
|
||||
}
|
||||
ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
|
||||
}
|
||||
DumpAbstractActor(actor, ofs);
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
MS_EXCEPTION_IF_NULL(actor->input_device_context_);
|
||||
MS_EXCEPTION_IF_NULL(actor->output_device_context_);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name()
|
||||
<< "\tinput_device_context:" << actor->input_device_context_->device_context_key().ToString()
|
||||
<< "\toutput_device_context:" << actor->output_device_context_->device_context_key().ToString()
|
||||
<< "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_
|
||||
<< "\n";
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
|
||||
|
||||
auto device_tensor = actor->output_;
|
||||
if (device_tensor != nullptr) {
|
||||
|
@ -3049,13 +3051,7 @@ void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) c
|
|||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
|
||||
}
|
||||
|
||||
if (actor->device_tensor_store_key_.second != nullptr) {
|
||||
ofs << "\t\tdevice_tensor_stores:" << 1 << "\n ";
|
||||
ofs << "\t\t\tto_input_index:" << actor->device_tensor_store_key_.first
|
||||
<< "\tfrom_node_name:" << actor->device_tensor_store_key_.second->fullname_with_scope() << "\n";
|
||||
}
|
||||
|
||||
DumpBaseActor(actor, ofs);
|
||||
DumpAbstractActor(actor, ofs);
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
|
@ -3144,6 +3140,7 @@ void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &of
|
|||
for (const auto &control_arrow : actor->output_control_arrows_) {
|
||||
ofs << "\t\t\tto_actor_name:" << control_arrow;
|
||||
}
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
|
||||
|
@ -3191,6 +3188,7 @@ void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &of
|
|||
ofs << "\t\t\t\t from index:" << arrow << '\n';
|
||||
}
|
||||
}
|
||||
ofs << "\n";
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -302,7 +302,7 @@ class GraphScheduler {
|
|||
|
||||
// Display the actor information of corresponding kernel graph.
|
||||
void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const;
|
||||
void DumpBaseActor(const OpActor<DeviceTensor> *actor, std::ofstream &ofs) const;
|
||||
void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const;
|
||||
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
|
||||
void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
|
||||
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
|
||||
|
|
Loading…
Reference in New Issue