!14294 add the impl of runtime actors

From: @limingqi107
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-31 18:03:19 +08:00 committed by Gitee
commit d2ee376b7b
22 changed files with 733 additions and 60 deletions

View File

@ -214,6 +214,7 @@ set(SUB_COMP
backend/kernel_compiler backend/kernel_compiler
backend/session backend/session
runtime/device runtime/device
runtime/framework
runtime/hardware runtime/hardware
runtime/hccl_adapter runtime/hccl_adapter
frontend/optimizer frontend/optimizer

View File

@ -44,6 +44,7 @@ class TaskGenerator;
namespace gpu { namespace gpu {
class GPUKernelRuntime; class GPUKernelRuntime;
class GPUMemoryManager; class GPUMemoryManager;
class GPUDeviceContext;
} // namespace gpu } // namespace gpu
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
@ -107,6 +108,7 @@ class DeviceAddress : public mindspore::DeviceSync {
friend class mindspore::device::cpu::CPUDeviceContext; friend class mindspore::device::cpu::CPUDeviceContext;
friend class mindspore::device::gpu::GPUKernelRuntime; friend class mindspore::device::gpu::GPUKernelRuntime;
friend class mindspore::device::gpu::GPUMemoryManager; friend class mindspore::device::gpu::GPUMemoryManager;
friend class mindspore::device::gpu::GPUDeviceContext;
friend class mindspore::device::ascend::AscendKernelRuntime; friend class mindspore::device::ascend::AscendKernelRuntime;
friend class mindspore::device::ascend::AscendMemoryManager; friend class mindspore::device::ascend::AscendMemoryManager;
friend class mindspore::device::ascend::DataDumper; friend class mindspore::device::ascend::DataDumper;

View File

@ -0,0 +1,8 @@
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include)
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/src)
file(GLOB_RECURSE FRAMEWORK_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${FRAMEWORK_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_RUNTIME_FRAMEWORK)
add_library(_mindspore_runtime_framework_obj OBJECT ${FRAMEWORK_SRC_LIST})

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ACTOR_COMMON_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ACTOR_COMMON_H_
#include <utility>
#include "mindrt/include/actor/op_actor.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
// The execution result of actor.
constexpr int kSuccess = 0;
constexpr int kFailure = 1;
#define SET_OPCONTEXT_FAIL_RET_WITH_ERROR(op_context, message) \
{ \
MS_LOG(ERROR) << message; \
op_context.SetFailed(kFailure); \
return; \
}
#define SET_OPCONTEXT_SUCCESS_RET(op_context) \
{ \
op_context.SetSuccess(kSuccess); \
return; \
}
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ACTOR_COMMON_H_

View File

@ -0,0 +1,163 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "common/trans.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (buffers_.size() == buffer_capacity_) {
// Send output to trigger computing and free memory.
SendOutput(context);
FreeMemory(context);
buffers_.pop();
return;
}
// Construct device tensors and fill to the buffers from member nodes.
FillDataBuffer();
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}
// Allocate memory for device tensors.
AllocateMemory(context);
}
void DataSourceActor::AllocateMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.back();
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, device_tensors, device_context_, context, GetAID());
}
void DataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) {
auto device_tensors = buffers_.front();
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, device_tensors, device_context_, context);
}
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}
// Send output data.
auto output_device_tensors = buffers_.front();
for (auto &op_arrow : output_op_arrows_) {
MS_EXCEPTION_IF_NULL(op_arrow);
if (IntToSize(op_arrow->from_output_index_) >= output_device_tensors.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The output index is of range.");
}
auto device_address = output_device_tensors[op_arrow->from_output_index_];
auto data = std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, device_address, op_arrow->to_input_index_);
Async(op_arrow->to_op_id_, &KernelActor::RunOpData, data, context);
}
}
void DeviceQueueDataSourceActor::FillDataBuffer() {
// Construct device tensors.
std::vector<DeviceTensor *> device_tensors;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(data_kernel_); ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(data_kernel_, i, false);
MS_EXCEPTION_IF_NULL(device_address);
device_tensors.emplace_back(device_address.get());
}
buffers_.push(device_tensors);
}
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}
// Construct outputs of data kernel launching.
auto device_tensors = buffers_.back();
std::vector<AddressPtr> kernel_outputs;
for (auto &device_tensor : device_tensors) {
MS_EXCEPTION_IF_NULL(device_tensor);
kernel_outputs.emplace_back(std::make_shared<Address>(device_tensor->GetMutablePtr(), device_tensor->GetSize()));
}
// Copy data from device queue by data kernel launching.
std::vector<AddressPtr> empty_address;
auto kernel_mod = AnfAlgo::GetKernelMod(data_kernel_);
auto ret = device_context_->LaunchKernel(kernel_mod, empty_address, empty_address, kernel_outputs);
if (!ret) {
std::string error_info = "Launch kernel failed: " + data_kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
// Send output to trigger computing and free memory.
SendOutput(context);
FreeMemory(context);
buffers_.pop();
}
void HostQueueDataSourceActor::FillDataBuffer() {
// Construct device tensors.
std::vector<DeviceTensor *> device_tensors;
for (auto &data_node : data_nodes_) {
auto device_address = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
MS_EXCEPTION_IF_NULL(device_address);
device_tensors.emplace_back(device_address.get());
}
buffers_.push(device_tensors);
}
void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}
// Get host tensors from host queue and get device tensors from buffers.
MS_EXCEPTION_IF_NULL(host_queue_);
auto host_tensors = host_queue_->PullData();
auto device_tensors = buffers_.back();
if (host_tensors.size() != device_tensors.size()) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context),
"The length of host tensors is not equal to the length of device tensors.");
}
// Copy data from host tensor to device tensor.
for (size_t i = 0; i < host_tensors.size(); ++i) {
auto host_tensor = host_tensors[i];
auto device_tensor = device_tensors[i];
MS_EXCEPTION_IF_NULL(host_tensor);
MS_EXCEPTION_IF_NULL(device_tensor);
if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(data_nodes_[i], 0),
LongToSize(host_tensor->data().nbytes()), host_tensor->data_type(),
host_tensor->data_c())) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "SyncHostToDevice failed.");
}
}
// Send output to trigger computing and free memory.
SendOutput(context);
FreeMemory(context);
buffers_.pop();
}
} // namespace runtime
} // namespace mindspore

View File

@ -22,43 +22,73 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <queue> #include <queue>
#include "mindrt/include/actor/op_actor.h" #include <utility>
#include "mindrt/include/async/future.h" #include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/memory_interface_actor.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/device_tensor_store.h" #include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/host_tensor_queue.h" #include "runtime/framework/host_tensor_queue.h"
#include "base/base.h" #include "base/base.h"
namespace mindspore { namespace mindspore {
namespace runtime { namespace runtime {
// The data source actor is used to fetch data and process them into device tensors, using mindspore::device::DeviceContext;
// and then send them to kernel actor.
class DataSourceActor : public ActorBase { // The data source actor is used to fetch data from data source and process them into device tensors,
// and then send them to kernel actor. The processing flow is FetchData -> FillDataBuffer -> AllocateMemory
// -> OnMemoryAllocFinish -> SendOutput -> FreeMemory.
class DataSourceActor : public MemoryInterfaceActor {
public: public:
DataSourceActor(std::string name, size_t buffer_capacity) : ActorBase(name), buffer_capacity_(buffer_capacity) {} DataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
const AID memory_manager_aid)
: MemoryInterfaceActor(name),
buffer_capacity_(buffer_capacity),
device_context_(device_context),
memory_manager_aid_(memory_manager_aid) {}
virtual ~DataSourceActor() = default; virtual ~DataSourceActor() = default;
// The process entry of data processing. // The process entry of data processing.
virtual void FetchData(OpContext<DeviceTensor> *context) = 0; void FetchData(OpContext<DeviceTensor> *context);
// The memory related operation interface.
void AllocateMemory(OpContext<DeviceTensor> *context) override;
void FreeMemory(OpContext<DeviceTensor> *context) override;
// Copy data from data source to the device tensor buffer of actor after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override{};
protected: protected:
// Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching.
virtual void FillDataBuffer() = 0;
// Send output to downstream actors to trigger computing after fetching data finished.
void SendOutput(OpContext<DeviceTensor> *context);
// To trigger kernel actors running by op arrows. // To trigger kernel actors running by op arrows.
std::vector<OpArrowPtr> output_op_arrows_; std::vector<OpArrowPtr> output_op_arrows_;
// The buffers store the data. // The buffers store the device tensors.
std::queue<std::vector<DeviceTensorPtr>> buffers_; std::queue<std::vector<DeviceTensor *>> buffers_;
size_t buffer_capacity_; size_t buffer_capacity_;
// The sequential number of corresponding batch data. // The device interface of data copy.
std::queue<uuids::uuid *> sequential_nums_; const DeviceContext *device_context_;
// The id of memory manager actor. Send message to it for alloc and free memory during the data processing.
const AID memory_manager_aid_;
}; };
// The class represents that the data source is device queue. // The class represents that the data source is device queue.
class DeviceQueueDataSourceActor : public DataSourceActor { class DeviceQueueDataSourceActor : public DataSourceActor {
public: public:
DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity) : DataSourceActor(name, buffer_capacity) {} DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
virtual ~DeviceQueueDataSourceActor() = default; const AID memory_manager_aid)
: DataSourceActor(name, buffer_capacity, device_context, memory_manager_aid) {}
~DeviceQueueDataSourceActor() override = default;
void FetchData(OpContext<DeviceTensor> *context) override; void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
protected:
void FillDataBuffer() override;
private: private:
friend class GraphScheduler; friend class GraphScheduler;
@ -70,11 +100,15 @@ class DeviceQueueDataSourceActor : public DataSourceActor {
// The class represents that the data source is host queue. // The class represents that the data source is host queue.
class HostQueueDataSourceActor : public DataSourceActor { class HostQueueDataSourceActor : public DataSourceActor {
public: public:
HostQueueDataSourceActor(std::string name, size_t buffer_capacity, HostTensorQueuePtr host_queue) HostQueueDataSourceActor(std::string name, size_t buffer_capacity, const DeviceContext *device_context,
: DataSourceActor(name, buffer_capacity), host_queue_(host_queue) {} const AID memory_manager_aid, HostTensorQueuePtr host_queue)
virtual ~HostQueueDataSourceActor() = default; : DataSourceActor(name, buffer_capacity, device_context, memory_manager_aid), host_queue_(host_queue) {}
~HostQueueDataSourceActor() override = default;
void FetchData(OpContext<DeviceTensor> *context) override; void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
protected:
void FillDataBuffer() override;
private: private:
friend class GraphScheduler; friend class GraphScheduler;

View File

@ -0,0 +1,190 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void KernelActor::RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
// When all the input data are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
FetchOutputDeviceTensor();
FetchWorkspaceDeviceTensor();
AllocateMemory(context);
}
}
void KernelActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
// When all the input data are collected, then allocate memory and callback launch.
if (CheckLaunchCondition(context)) {
FetchInputDeviceTensor(context);
FetchOutputDeviceTensor();
FetchWorkspaceDeviceTensor();
AllocateMemory(context);
}
}
void KernelActor::AllocateMemory(OpContext<DeviceTensor> *context) {
std::vector<DeviceTensor *> alloc_list(output_device_tensors_);
alloc_list.insert(alloc_list.end(), workspace_device_tensors_.begin(), workspace_device_tensors_.end());
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, alloc_list, device_context_, context, GetAID());
}
void KernelActor::FreeMemory(OpContext<DeviceTensor> *context) {
std::vector<DeviceTensor *> free_list(input_device_tensors_);
free_list.insert(free_list.end(), output_device_tensors_.begin(), output_device_tensors_.end());
free_list.insert(free_list.end(), workspace_device_tensors_.begin(), workspace_device_tensors_.end());
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, free_list, device_context_, context);
}
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(kernel_);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_);
std::vector<AddressPtr> kernel_inputs;
std::vector<AddressPtr> kernel_outputs;
std::vector<AddressPtr> kernel_workspaces;
FetchLaunchArgs(&kernel_inputs, &kernel_outputs, &kernel_workspaces);
MS_EXCEPTION_IF_NULL(device_context_);
auto ret = device_context_->LaunchKernel(kernel_mod, kernel_inputs, kernel_workspaces, kernel_outputs);
if (!ret) {
std::string error_info = "Launch kernel failed: " + kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
SendOutput(context);
FreeMemory(context);
}
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter == input_op_datas_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
return false;
}
}
if (input_controls_num_ != 0) {
auto control_iter = input_op_controls_.find(context->sequential_num_);
if (control_iter == input_op_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_controls_num_) {
return false;
}
}
return true;
}
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto input_size = input_datas_num_ + device_tensor_store_keys_.size();
input_device_tensors_.resize(input_size);
auto data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end()) {
for (auto &input_data : data_iter->second) {
MS_EXCEPTION_IF_NULL(input_data);
input_device_tensors_[input_data->index_] = input_data->data_;
}
}
for (auto &device_tensor_store_key : device_tensor_store_keys_) {
auto device_tensor = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second);
input_device_tensors_[device_tensor_store_key.first] = device_tensor.get();
}
}
void KernelActor::FetchOutputDeviceTensor() {
output_device_tensors_.clear();
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel_); ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_, i, false);
MS_EXCEPTION_IF_NULL(device_address);
output_device_tensors_.emplace_back(device_address.get());
}
}
void KernelActor::FetchWorkspaceDeviceTensor() {
workspace_device_tensors_.clear();
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
if (workspace_sizes[i] != 0) {
auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel_, i);
MS_EXCEPTION_IF_NULL(device_address);
workspace_device_tensors_.emplace_back(device_address.get());
}
}
}
void KernelActor::FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs,
std::vector<AddressPtr> *kernel_workspaces) {
MS_EXCEPTION_IF_NULL(kernel_inputs);
MS_EXCEPTION_IF_NULL(kernel_outputs);
MS_EXCEPTION_IF_NULL(kernel_workspaces);
for (auto &input : input_device_tensors_) {
MS_EXCEPTION_IF_NULL(input);
kernel_inputs->emplace_back(std::make_shared<Address>(input->GetMutablePtr(), input->GetSize()));
}
for (auto &output : output_device_tensors_) {
MS_EXCEPTION_IF_NULL(output);
kernel_outputs->emplace_back(std::make_shared<Address>(output->GetMutablePtr(), output->GetSize()));
}
for (auto &workspace : workspace_device_tensors_) {
MS_EXCEPTION_IF_NULL(workspace);
kernel_workspaces->emplace_back(std::make_shared<Address>(workspace->GetMutablePtr(), workspace->GetSize()));
}
}
void KernelActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
// Send output data.
for (auto &op_arrow : output_op_arrows_) {
MS_EXCEPTION_IF_NULL(op_arrow);
if (IntToSize(op_arrow->from_output_index_) >= output_device_tensors_.size()) {
std::string error_info = "The output index is out of range: " + kernel_->ToString();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
auto device_address = output_device_tensors_[op_arrow->from_output_index_];
auto data = std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, device_address, op_arrow->to_input_index_);
Async(op_arrow->to_op_id_, &KernelActor::RunOpData, data, context);
}
// Send output control.
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &output_control : output_op_controls_) {
Async(output_control, &OpActor::RunOpControl, source_aid, context);
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -22,7 +22,8 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include "mindrt/include/actor/op_actor.h" #include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/actor/memory_interface_actor.h"
#include "runtime/hardware/device_context.h" #include "runtime/hardware/device_context.h"
#include "runtime/framework/device_tensor_store.h" #include "runtime/framework/device_tensor_store.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
@ -31,38 +32,47 @@
namespace mindspore { namespace mindspore {
namespace runtime { namespace runtime {
using mindspore::device::DeviceContext; using mindspore::device::DeviceContext;
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
// The kernel actor is used to receive the device tensors and control info to luanch kernel. // The kernel actor is used to receive the device tensors and control info to luanch kernel.
class KernelActor : public OpActor<DeviceTensor> { // The processing flow is RunOpData/RunOpControl -> CheckLaunchCondition -> AllocateMemory
// -> OnMemoryAllocFinish -> LaunchKernel -> SendOutput -> FreeMemory.
class KernelActor : public MemoryInterfaceActor {
public: public:
KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context) KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context, const AID memory_manager_aid)
: OpActor(name), kernel_(kernel), device_context_(device_context), input_datas_num_(0), input_controls_num_(0) {} : MemoryInterfaceActor(name),
virtual ~KernelActor() = default; kernel_(kernel),
device_context_(device_context),
memory_manager_aid_(memory_manager_aid),
input_datas_num_(0),
input_controls_num_(0) {}
~KernelActor() override = default;
// The kernel actor run when receive the input data. // The kernel actor run when receive the input data.
void RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) override; void RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) override;
// The kernel actor run when receive the input control. // The kernel actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override; void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
// The memory related operation interface.
void AllocateMemory(OpContext<DeviceTensor> *context) override;
void FreeMemory(OpContext<DeviceTensor> *context) override;
// The real kernel launch processing after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
private: private:
friend class GraphScheduler; friend class GraphScheduler;
// Check whether satisfy the condition for launch. // Check whether satisfy the condition for launch.
bool CheckLaunchCondition(const uuids::uuid *sequential_num); bool CheckLaunchCondition(OpContext<DeviceTensor> *context);
// Fetch the args of kernel launch. // Fetch the args of kernel launch.
void FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs, void FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs,
std::vector<AddressPtr> *kernel_workspaces); std::vector<AddressPtr> *kernel_workspaces);
// The real kernel launch processing.
void Launch(OpContext<DeviceTensor> *context);
// Send output data and output controls when finish kernel launch. // Send output data and output controls when finish kernel launch.
void SendOutput(OpContext<DeviceTensor> *context); void SendOutput(OpContext<DeviceTensor> *context);
void AllocateMemory(OpContext<DeviceTensor> *context);
void FreeMemory(OpContext<DeviceTensor> *context);
// Fetch the device tensor for launch. // Fetch the device tensor for launch.
void FetchInputDeviceTensor(const uuids::uuid *sequential_num); void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
void FetchOutputDeviceTensor(); void FetchOutputDeviceTensor();
void FetchWorkspaceDeviceTensor(); void FetchWorkspaceDeviceTensor();
@ -70,6 +80,9 @@ class KernelActor : public OpActor<DeviceTensor> {
// The device interface of kernel launch. // The device interface of kernel launch.
const DeviceContext *device_context_; const DeviceContext *device_context_;
// The id of memory manager actor. Send message to it for alloc and free memory during the kernel launch.
const AID memory_manager_aid_;
// The dependent input data number. // The dependent input data number.
size_t input_datas_num_; size_t input_datas_num_;
// The dependent input controls number. // The dependent input controls number.
@ -79,9 +92,9 @@ class KernelActor : public OpActor<DeviceTensor> {
std::vector<std::pair<size_t, void *>> device_tensor_store_keys_; std::vector<std::pair<size_t, void *>> device_tensor_store_keys_;
// The device tensors for launch. // The device tensors for launch.
std::vector<DeviceTensorPtr> input_device_tensors_; std::vector<DeviceTensor *> input_device_tensors_;
std::vector<DeviceTensorPtr> output_device_tensors_; std::vector<DeviceTensor *> output_device_tensors_;
std::vector<DeviceTensorPtr> workspace_device_tensors_; std::vector<DeviceTensor *> workspace_device_tensors_;
}; };
using KernelActorPtr = std::shared_ptr<KernelActor>; using KernelActorPtr = std::shared_ptr<KernelActor>;

View File

@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
if (input_op_controls_[sequential_num].size() == input_controls_num_) {
current_count_++;
if (current_count_ == loop_count_) {
current_count_ = 0;
SET_OPCONTEXT_SUCCESS_RET((*context));
}
// Send output control.
for (auto &data_source_aid : data_source_aids_) {
Async(data_source_aid, &DataSourceActor::FetchData, context);
}
auto source_aid = const_cast<AID *>(&GetAID());
for (auto &kernel_aid : no_input_kernel_aids_) {
Async(kernel_aid, &KernelActor::RunOpControl, source_aid, context);
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -21,7 +21,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "mindrt/include/actor/op_actor.h" #include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h" #include "runtime/framework/device_tensor_store.h"
namespace mindspore { namespace mindspore {
@ -32,7 +32,7 @@ class LoopCountActor : public OpActor<DeviceTensor> {
public: public:
LoopCountActor(std::string name, size_t loop_count) LoopCountActor(std::string name, size_t loop_count)
: OpActor(name), loop_count_(loop_count), current_count_(0), input_controls_num_(0) {} : OpActor(name), loop_count_(loop_count), current_count_(0), input_controls_num_(0) {}
virtual ~LoopCountActor() = default; ~LoopCountActor() override = default;
// The loop count actor run when receive the input control. // The loop count actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override; void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_INTERFACE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_INTERFACE_ACTOR_H_
#include <utility>
#include <string>
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/device_tensor_store.h"
namespace mindspore {
namespace runtime {
// The actor represents a set of common memory related operations of actor.
class MemoryInterfaceActor : public OpActor<DeviceTensor> {
public:
explicit MemoryInterfaceActor(std::string name) : OpActor(name) {}
virtual ~MemoryInterfaceActor() = default;
virtual void AllocateMemory(OpContext<DeviceTensor> *context) = 0;
virtual void FreeMemory(OpContext<DeviceTensor> *context) = 0;
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) = 0;
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_INTERFACE_ACTOR_H_

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace runtime {
void MemoryManagerActor::AllocateMemory(std::vector<DeviceTensor *> alloc_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context, const AID from_aid) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
for (auto &device_tensor : alloc_list) {
MS_EXCEPTION_IF_NULL(device_tensor);
if (device_tensor->GetPtr() != nullptr) {
continue;
}
// Allocate memory through the device context.
if (!device_context->AllocateMemory(device_tensor, device_tensor->GetSize())) {
std::string error_info = "Device memory isn't enough and alloc failed, actor name: " + from_aid.Name() +
", alloc size: " + std::to_string(device_tensor->GetSize());
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*op_context), error_info);
}
}
// Call back to the from actor to process after memory allocation finished.
Async(from_aid, &MemoryInterfaceActor::OnMemoryAllocFinish, op_context);
}
void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *) {
MS_EXCEPTION_IF_NULL(device_context);
for (auto &device_tensor : free_list) {
MS_EXCEPTION_IF_NULL(device_tensor);
// The reference count is decremented to zero to free memory, and reset to the original count.
device_tensor->DecreaseRefCountUsed();
if (device_tensor->ref_count_dynamic_used() == 0) {
// Free memory through the device context.
device_context->FreeMemory(device_tensor);
device_tensor->ResetRefCountUsed();
}
}
}
} // namespace runtime
} // namespace mindspore

View File

@ -21,7 +21,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "mindrt/include/actor/op_actor.h" #include "runtime/framework/actor/actor_common.h"
#include "runtime/framework/device_tensor_store.h" #include "runtime/framework/device_tensor_store.h"
#include "runtime/hardware/device_context.h" #include "runtime/hardware/device_context.h"
@ -36,10 +36,10 @@ class MemoryManagerActor : public ActorBase {
~MemoryManagerActor() override = default; ~MemoryManagerActor() override = default;
// The process entry of memory alloc. // The process entry of memory alloc.
bool AllocateMemory(std::vector<DeviceTensorPtr> alloc_list, const DeviceContext *device_context, void AllocateMemory(std::vector<DeviceTensor *> alloc_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context); OpContext<DeviceTensor> *op_context, const AID from_aid);
// The process entry of memory free. // The process entry of memory free.
void FreeMemory(std::vector<DeviceTensorPtr> free_list, const DeviceContext *device_context, void FreeMemory(std::vector<DeviceTensor *> free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context); OpContext<DeviceTensor> *op_context);
}; };
} // namespace runtime } // namespace runtime

View File

@ -15,6 +15,7 @@
*/ */
#include "runtime/framework/graph_scheduler.h" #include "runtime/framework/graph_scheduler.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "mindrt/src/actor/actormgr.h" #include "mindrt/src/actor/actormgr.h"
#include "mindrt/include/async/async.h" #include "mindrt/include/async/async.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
@ -102,6 +103,23 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) {
} }
} // namespace } // namespace
void GraphScheduler::Initialize() {
if (init_) {
return;
}
init_ = true;
// Create memory manager actor.
auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
MS_EXCEPTION_IF_NULL(memory_manager_actor);
memory_manager_aid_ = memory_manager_actor->GetAID();
// Schedule memory manager actor, bind single thread to response to memory alloc and free quickly.
auto base_actor = static_cast<ActorReference>(memory_manager_actor);
auto actorMgr = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actorMgr);
(void)actorMgr->Spawn(base_actor, false);
}
ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
const std::vector<tensor::TensorPtr> *input_tensors, const std::vector<tensor::TensorPtr> *input_tensors,
GraphExecutionStrategy strategy) { GraphExecutionStrategy strategy) {
@ -191,7 +209,7 @@ ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceConte
auto actor_set = std::make_shared<ActorSet>(); auto actor_set = std::make_shared<ActorSet>();
MS_EXCEPTION_IF_NULL(actor_set); MS_EXCEPTION_IF_NULL(actor_set);
auto data_source_actors = BuildDataSourceActor(graph); auto data_source_actors = BuildDataSourceActor(graph, device_context);
actor_set->data_source_actors_.swap(data_source_actors); actor_set->data_source_actors_.swap(data_source_actors);
auto kernel_actors = BuildKernelActor(graph, device_context); auto kernel_actors = BuildKernelActor(graph, device_context);
@ -251,7 +269,8 @@ void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, Grap
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph); LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph);
} }
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph) { std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<DataSourceActorPtr> data_source_actors; std::vector<DataSourceActorPtr> data_source_actors;
@ -265,7 +284,8 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Kerne
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name; MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
auto host_queue = std::make_shared<HostTensorQueue>(); auto host_queue = std::make_shared<HostTensorQueue>();
graph_to_host_queue_.emplace(graph, host_queue); graph_to_host_queue_.emplace(graph, host_queue);
host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, host_queue); host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
data_source_actors.emplace_back(host_queue_ds_actor); data_source_actors.emplace_back(host_queue_ds_actor);
} }
host_queue_ds_actor->data_nodes_.emplace_back(input_node); host_queue_ds_actor->data_nodes_.emplace_back(input_node);
@ -279,7 +299,8 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Kerne
if (iter != execution_order.end()) { if (iter != execution_order.end()) {
auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor"; auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor";
MS_LOG(INFO) << "Create queue data source actor: " << actor_name; MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1); auto device_queue_ds_actor =
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor); MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
data_source_actors.emplace_back(device_queue_ds_actor); data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter; device_queue_ds_actor->data_kernel_ = *iter;
@ -295,7 +316,8 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const KernelGraphPt
auto execution_order = graph->execution_order(); auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) { for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) { if (IsKernelActor(kernel)) {
auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context); auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor); MS_EXCEPTION_IF_NULL(kernel_actor);
kernel_actors.emplace_back(kernel_actor); kernel_actors.emplace_back(kernel_actor);
} }

View File

@ -61,6 +61,9 @@ class GraphScheduler {
return instance; return instance;
} }
// The memory manager creating and scheduling.
void Initialize();
// Transform graph to actor DAG, contains build and link. // Transform graph to actor DAG, contains build and link.
ActorSet *Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, ActorSet *Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
const std::vector<tensor::TensorPtr> *input_tensors = nullptr, const std::vector<tensor::TensorPtr> *input_tensors = nullptr,
@ -87,7 +90,8 @@ class GraphScheduler {
void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy); void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
// The processing of actors build. // The processing of actors build.
std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph); std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph,
const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context); std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph); std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph);
LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph); LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
@ -114,6 +118,11 @@ class GraphScheduler {
// The second element of pair represents the output index of kernel actor corresponding to the device tensor. // The second element of pair represents the output index of kernel actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_; std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
// The id of memory manager actor.
AID memory_manager_aid_;
bool init_{false};
}; };
} // namespace runtime } // namespace runtime
} // namespace mindspore } // namespace mindspore

View File

@ -35,12 +35,12 @@ bool CPUDeviceContext::Initialize() {
return true; return true;
} }
bool CPUDeviceContext::AllocateMemory(const DeviceAddressPtr &address, size_t size) const { bool CPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(size); address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(size);
return true; return true;
} }
void CPUDeviceContext::FreeMemory(const DeviceAddressPtr &address) const { void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
static_cast<CPUMemoryManager *>(mem_manager_.get())->MemFree(address->ptr_); static_cast<CPUMemoryManager *>(mem_manager_.get())->MemFree(address->ptr_);
address->ptr_ = nullptr; address->ptr_ = nullptr;
} }

View File

@ -33,8 +33,8 @@ class CPUDeviceContext : public DeviceContext {
bool Initialize() override; bool Initialize() override;
bool AllocateMemory(const DeviceAddressPtr &address, size_t size) const override; bool AllocateMemory(DeviceAddress *const &address, size_t size) const override;
void FreeMemory(const DeviceAddressPtr &address) const override; void FreeMemory(DeviceAddress *const &address) const override;
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
void CreateKernel(const std::vector<CNodePtr> &nodes) const override; void CreateKernel(const std::vector<CNodePtr> &nodes) const override;

View File

@ -52,13 +52,13 @@ class DeviceContext {
virtual void Destroy() {} virtual void Destroy() {}
// Relevant function to allocate and free device memory. // Relevant function to allocate and free device memory.
virtual bool AllocateMemory(const DeviceAddressPtr &address, size_t size) const = 0; virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const = 0;
virtual void FreeMemory(const DeviceAddressPtr &address) const = 0; virtual void FreeMemory(DeviceAddress *const &address) const = 0;
// Allocate continuous device memory end to end into 'addr_list'. // Allocate continuous device memory end to end into 'addr_list'.
// Communication operators may need continuous memory for input and output // Communication operators may need continuous memory for input and output
// to optimize the communication performance. // to optimize the communication performance.
virtual bool AllocateContinuousMemory(const DeviceAddressPtrList &addr_list, size_t total_size, virtual bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const { const std::vector<size_t> &size_list) const {
return true; return true;
} }

View File

@ -115,15 +115,41 @@ void GPUDeviceContext::Destroy() {
} }
} }
bool GPUDeviceContext::AllocateMemory(const DeviceAddressPtr &address, size_t size) const { bool GPUDeviceContext::AllocateMemory(DeviceAddress *const &address, size_t size) const {
return mem_manager_->MallocMemFromMemPool(address, size); MS_EXCEPTION_IF_NULL(address);
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
if (!device_ptr) {
return false;
}
address->ptr_ = device_ptr;
address->size_ = size;
address->from_mem_pool_ = true;
return true;
} }
void GPUDeviceContext::FreeMemory(const DeviceAddressPtr &address) const { mem_manager_->FreeMemFromMemPool(address); } void GPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(address->ptr_);
mem_manager_->FreeMemFromMemPool(address->ptr_);
address->ptr_ = nullptr;
}
bool GPUDeviceContext::AllocateContinuousMemory(const DeviceAddressPtrList &addr_list, size_t total_size, bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const { const std::vector<size_t> &size_list) const {
return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); auto device_ptr_list = mem_manager_->MallocContinuousMemFromMemPool(total_size, size_list);
if (device_ptr_list.size() == 0) {
return false;
}
if (addr_list.size() != device_ptr_list.size()) {
MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list.";
}
for (size_t i = 0; i < addr_list.size(); i++) {
MS_EXCEPTION_IF_NULL(device_ptr_list[i]);
MS_EXCEPTION_IF_NULL(addr_list[i]);
addr_list[i]->ptr_ = device_ptr_list[i];
addr_list[i]->from_mem_pool_ = true;
}
return true;
} }
void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const { void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {

View File

@ -38,9 +38,9 @@ class GPUDeviceContext : public DeviceContext {
// Release device memory, stream, cudnn and cublas handle, etc. // Release device memory, stream, cudnn and cublas handle, etc.
void Destroy() override; void Destroy() override;
bool AllocateMemory(const DeviceAddressPtr &address, size_t size) const override; bool AllocateMemory(DeviceAddress *const &address, size_t size) const override;
void FreeMemory(const DeviceAddressPtr &address) const override; void FreeMemory(DeviceAddress *const &address) const override;
bool AllocateContinuousMemory(const DeviceAddressPtrList &addr_list, size_t total_size, bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const override; const std::vector<size_t> &size_list) const override;
void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override; void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;

View File

@ -33,6 +33,9 @@ endif()
set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE) set_property(SOURCE ${CORE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_CORE)
add_library(mindspore_core STATIC ${CORE_SRC_LIST}) add_library(mindspore_core STATIC ${CORE_SRC_LIST})
target_link_libraries(mindspore_core PRIVATE mindspore_gvar) target_link_libraries(mindspore_core PRIVATE mindspore_gvar)
if(NOT(COMPILE_LITE))
target_link_libraries(mindspore_core PRIVATE mindrt_mid)
endif()
if(USE_GLOG) if(USE_GLOG)
target_link_libraries(mindspore_core PRIVATE mindspore::glog) target_link_libraries(mindspore_core PRIVATE mindspore::glog)

View File

@ -58,11 +58,19 @@ struct OpContext {
std::vector<Promise<int>> *results_; std::vector<Promise<int>> *results_;
const void *kernel_call_back_before_; const void *kernel_call_back_before_;
const void *kernel_call_back_after_; const void *kernel_call_back_after_;
void SetFailed(int32_t code) { void SetFailed(int32_t code) {
for (auto promise : *results_) { for (auto promise : *results_) {
promise.SetFailed(code); promise.SetFailed(code);
} }
} }
void SetSuccess(int32_t code) {
for (auto promise : *results_) {
promise.SetValue(code);
}
}
void SetResult(size_t index, int value) { results_->at(index).SetValue(value); } void SetResult(size_t index, int value) { results_->at(index).SetValue(value); }
}; };