!21091 Fix CodeDex issues in the unified runtime

Merge pull request !21091 from limingqi107/bug_fix4
This commit is contained in:
i-robot 2021-07-30 10:35:43 +00:00 committed by Gitee
commit 0816f95653
33 changed files with 181 additions and 162 deletions

View File

@ -277,6 +277,7 @@ void DynamicMemPoolBestFit::ReleaseDeviceRes() {
if (!FreeDeviceMem(device_addr)) {
MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
}
device_addr = nullptr;
}
}

View File

@ -319,10 +319,10 @@ class KernelGraph : public FuncGraph {
void InsertToSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
allreduce_to_send_recv_pairs_[allreduce] = send_recv_pair;
}
std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() {
const std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() const {
return allreduce_from_send_recv_pairs_;
}
std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() {
const std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() const {
return allreduce_to_send_recv_pairs_;
}

View File

@ -65,7 +65,7 @@ class DumpJsonParser {
void ClearGraph() { graphs_.clear(); }
void SaveGraph(session::KernelGraph *graph) { (void)graphs_.emplace_back(graph); }
std::vector<session::KernelGraph *> &graphs() { return graphs_; }
const std::vector<session::KernelGraph *> &graphs() const { return graphs_; }
private:
DumpJsonParser() = default;

View File

@ -158,7 +158,7 @@ void AscendDeviceAddress::SyncStream() const {
MS_LOG(DEBUG) << "Finish!";
}
bool AscendDeviceAddress::SyncDeviceToHost(size_t size, void *host_ptr) const {
bool AscendDeviceAddress::SyncDeviceToHost(size_t size, void *const host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr);
SyncStream();
SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);

View File

@ -42,7 +42,7 @@ class AscendDeviceAddress : public DeviceAddress {
const KernelWithIndex &node_index)
: DeviceAddress(ptr, size, format, type_id, node_index) {}
~AscendDeviceAddress() override;
bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
bool SyncDeviceToHost(size_t size, void *const host_ptr) const override;
bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,

View File

@ -106,7 +106,7 @@ void IntToLong(void *dst, const void *src, size_t elem_num) {
}
}
void ConvertSameType(void *dst, const void *src, size_t size, TypeId type) {
void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type) {
if (dst == nullptr || src == nullptr) {
return;
}

View File

@ -31,7 +31,7 @@ void ShortToInt(void *dst, const void *src, size_t elem_num);
void IntToShort(void *dst, const void *src, size_t elem_num);
void LongToInt(void *dst, const void *src, size_t elem_num);
void IntToLong(void *dst, const void *src, size_t elem_num);
void ConvertSameType(void *dst, const void *src, size_t size, TypeId type);
void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type);
template <typename T>
void ConvertSameType(T *dst, const T *src, size_t elem_num) {

View File

@ -76,6 +76,8 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId
} else if (ret_code != EOK) {
MS_LOG(ERROR) << "Failed to copy tensor!";
return false;
} else {
return true;
}
} else if (type == kNumberTypeFloat16 && type_id_ == kNumberTypeFloat32) {
FloatToHalf(host_ptr, ptr_, size >> 1);

View File

@ -70,8 +70,8 @@ class KernelInfo : public KernelInfoDevice {
uint32_t graph_id() const { return graph_id_; }
bool operator==(const KernelInfo &other) const;
bool is_feature_map() const { return is_feature_map_; }
std::vector<std::shared_ptr<DeviceAddress>> &output_address_list() { return output_address_list_; }
std::vector<std::shared_ptr<DeviceAddress>> &workspace_address_list() { return workspace_address_list_; }
const std::vector<std::shared_ptr<DeviceAddress>> &output_address_list() const { return output_address_list_; }
const std::vector<std::shared_ptr<DeviceAddress>> &workspace_address_list() const { return workspace_address_list_; }
private:
bool is_feature_map_;

View File

@ -114,7 +114,7 @@ bool IsGatherActor(const AnfNodePtr &front_node,
return false;
}
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
MS_EXCEPTION_IF_NULL(dst_device_tensor);
MS_EXCEPTION_IF_NULL(src_device_tensor);
if (src_device_tensor->GetSize() != dst_device_tensor->GetSize()) {

View File

@ -89,7 +89,7 @@ bool IsGatherActor(const AnfNodePtr &front_node,
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor);
// Copy data from src_device_tensor to dst_device_tensor.
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
} // namespace runtime
} // namespace mindspore

View File

@ -38,7 +38,7 @@ void CopyActor::Init() {
}
}
void CopyActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
@ -49,7 +49,7 @@ void CopyActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTens
}
}
void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
@ -60,17 +60,17 @@ void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *contex
}
}
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_, output_device_context_,
context, GetAID());
}
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_, input_device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_, output_device_context_, context);
}
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(output_device_tensor_[0]);
MS_EXCEPTION_IF_NULL(input_device_tensor_[0]);
@ -96,7 +96,7 @@ void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
SendOutput(context);
}
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
@ -120,7 +120,7 @@ bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
return true;
}
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *context) {
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(input_device_context_);
@ -156,7 +156,7 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *context) {
}
}
void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
@ -179,7 +179,7 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
}
}
void CopyActor::EraseInput(OpContext<DeviceTensor> *context) {
void CopyActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);

View File

@ -41,34 +41,36 @@ class CopyActor : public MemoryAwareActor {
memory_manager_aid_(memory_manager_aid),
input_datas_num_(0),
input_controls_num_(0),
input_device_context_(nullptr),
output_device_context_(nullptr),
output_(nullptr) {}
~CopyActor() override = default;
void Init() override;
// The copy actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
// The copy actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// The memory related operation interface.
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
// The copy processing after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
private:
friend class GraphScheduler;
// Check whether satisfy the condition for copy.
bool CheckCopyCondition(OpContext<DeviceTensor> *context) const;
bool CheckCopyCondition(OpContext<DeviceTensor> *const context) const;
// Fetch the device tensor for copy.
void FetchDeviceTensor(OpContext<DeviceTensor> *context);
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);
// Send output data and output controls when finish copy.
void SendOutput(OpContext<DeviceTensor> *context) const;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Erase input data and input controls when finish copy.
void EraseInput(OpContext<DeviceTensor> *context);
void EraseInput(OpContext<DeviceTensor> *const context);
// The id of memory manager actor. Send message to it for alloc and free memory during the copy.
const AID memory_manager_aid_;

View File

@ -35,7 +35,7 @@ void DataSourceActor::Init() {
}
}
void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) {
void DataSourceActor::FetchData(OpContext<DeviceTensor> *const context) {
MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") fetches data.";
MS_EXCEPTION_IF_NULL(context);
// Pop the data of last time.
@ -53,7 +53,7 @@ void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) {
SendMemoryAllocReq(context);
}
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) {
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// No output.
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
@ -124,17 +124,17 @@ void DeviceQueueDataSourceActor::FillDataBuffer() {
buffers_.push(device_tensors);
}
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
auto &device_tensors = buffers_.back();
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_context_, context, GetAID());
}
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
auto &device_tensors = buffers_.front();
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_context_, context);
}
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
if (buffers_.size() == 0) {
@ -177,16 +177,16 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
SendOutput(context);
}
void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_context_, context, &GetAID());
}
void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
SendMemoryFreeReq(context);
SendOutput(context);
}
void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_kernel_, result_arrow->from_output_index_,
@ -194,7 +194,7 @@ void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
}
}
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *context) {
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
if (recorder_aid_ != nullptr) {
Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
device_context_, context);
@ -213,7 +213,7 @@ void HostQueueDataSourceActor::FillDataBuffer() {
buffers_.push(device_tensors);
}
void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
auto &device_tensors = buffers_.back();
if (IsSameDeviceType()) {
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
@ -224,7 +224,7 @@ void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *conte
}
}
void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
auto &device_tensors = buffers_.front();
if (IsSameDeviceType()) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
@ -233,7 +233,7 @@ void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *contex
}
}
void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
@ -283,7 +283,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
SendOutput(context);
}
void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
for (const auto &result_arrow : output_result_arrows_) {
MS_EXCEPTION_IF_NULL(result_arrow);
if (IntToSize(result_arrow->from_output_index_) >= data_nodes_.size()) {

View File

@ -53,13 +53,13 @@ class DataSourceActor : public DebugAwareActor {
void Init() override;
// The process entry of data processing.
void FetchData(OpContext<DeviceTensor> *context);
void FetchData(OpContext<DeviceTensor> *const context);
// The memory related operation interface.
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override{};
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override{};
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override{};
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override{};
// Copy data from data source to the device tensor buffer of actor after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override{};
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override{};
protected:
friend class GraphScheduler;
@ -68,13 +68,13 @@ class DataSourceActor : public DebugAwareActor {
virtual void FillDataBuffer() = 0;
// Send output result of graph output to output actor.
virtual void SendResult(OpContext<DeviceTensor> *context) = 0;
virtual void SendResult(OpContext<DeviceTensor> *const context) = 0;
// Send recorder info to recorder actor, only the device queue data source actor need.
virtual void SendRecorderInfo(OpContext<DeviceTensor> *context) {}
virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {}
// Send output to downstream actors to trigger computing after fetching data finished.
void SendOutput(OpContext<DeviceTensor> *context);
void SendOutput(OpContext<DeviceTensor> *const context);
// The output result arrows of graph output.
std::vector<DataArrowPtr> output_result_arrows_;
@ -105,17 +105,17 @@ class DeviceQueueDataSourceActor : public DataSourceActor {
void Init() override;
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
void SendDebugReq(OpContext<DeviceTensor> *context) override;
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
void SendDebugReq(OpContext<DeviceTensor> *const context) override;
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
protected:
void FillDataBuffer() override;
void SendResult(OpContext<DeviceTensor> *context) override;
void SendRecorderInfo(OpContext<DeviceTensor> *context) override;
void SendResult(OpContext<DeviceTensor> *const context) override;
void SendRecorderInfo(OpContext<DeviceTensor> *const context) override;
private:
friend class GraphScheduler;
@ -138,15 +138,15 @@ class HostQueueDataSourceActor : public DataSourceActor {
: DataSourceActor(name, buffer_capacity, memory_manager_aid, debug_aid, recorder_aid), host_queue_(host_queue) {}
~HostQueueDataSourceActor() override = default;
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
protected:
void FillDataBuffer() override;
void SendResult(OpContext<DeviceTensor> *context) override;
void SendResult(OpContext<DeviceTensor> *const context) override;
private:
friend class GraphScheduler;

View File

@ -31,7 +31,8 @@ namespace mindspore {
namespace runtime {
void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_,
const DeviceContext *device_context, OpContext<DeviceTensor> *op_context, const AID *from_aid) {
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context,
const AID *from_aid) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
@ -69,7 +70,7 @@ void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_in
Async(*from_aid, &DebugAwareActor::OnDebugFinish, op_context);
}
void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *from_aid) {
void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *const op_context, const AID *from_aid) {
MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(from_aid);

View File

@ -35,10 +35,10 @@ class DebugActor : public ActorBase {
// The debug of each node.
void Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context, const AID *from_aid);
OpContext<DeviceTensor> *const op_context, const AID *from_aid);
// The debug on step end.
void DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *from_aid);
void DebugOnStepEnd(OpContext<DeviceTensor> *const op_context, const AID *from_aid);
private:
// class members

View File

@ -27,8 +27,8 @@ class DebugAwareActor : public MemoryAwareActor {
public:
explicit DebugAwareActor(const std::string &name) : MemoryAwareActor(name) {}
virtual ~DebugAwareActor() = default;
virtual void SendDebugReq(OpContext<DeviceTensor> *context) {}
virtual void OnDebugFinish(OpContext<DeviceTensor> *context) {}
virtual void SendDebugReq(OpContext<DeviceTensor> *const context) {}
virtual void OnDebugFinish(OpContext<DeviceTensor> *const context) {}
};
} // namespace runtime
} // namespace mindspore

View File

@ -74,7 +74,7 @@ void KernelActor::Init() {
}
}
void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
void KernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_datas_[sequential_num].emplace_back(input_data);
@ -100,7 +100,7 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
}
}
void KernelActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
void KernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
@ -178,7 +178,7 @@ void FreeMemory(const std::vector<DeviceTensor *> &free_list, const DeviceContex
}
} // namespace
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
running_dependent_msg_num_ = 1;
if (strategy_ == GraphExecutionStrategy::kPipeline) {
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_, device_context_, context,
@ -188,7 +188,7 @@ void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
}
}
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
if (strategy_ == GraphExecutionStrategy::kPipeline) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_context_, context);
} else {
@ -196,7 +196,7 @@ void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
}
}
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(kernel_);
MS_EXCEPTION_IF_NULL(device_context_);
@ -224,7 +224,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
PostLaunchKernel(context);
}
void KernelActor::SendDebugReq(OpContext<DeviceTensor> *context) {
void KernelActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
running_dependent_msg_num_ = 1;
Async(*debug_aid_, &DebugActor::Debug, kernel_, &launch_info_, device_context_, context, &GetAID());
}
@ -234,7 +234,7 @@ void KernelActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
PostLaunchKernel(context);
}
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
@ -276,7 +276,8 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
}
}
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(input_data);
if ((input_data->data_ == nullptr) || (input_data->data_->DeviceType() == device_context_->GetDeviceAddressType())) {
return;
@ -312,7 +313,7 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
memory_free_list_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
}
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
MS_EXCEPTION_IF_NULL(device_context_);
@ -392,7 +393,7 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
}
}
void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *context) {
void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
// The input is invalid and needs to be erased when finish kernel launch.
@ -408,7 +409,7 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *context) {
SendOutput(context);
}
void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const {
void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
MS_EXCEPTION_IF_NULL(context);
if (strategy_ == GraphExecutionStrategy::kStep) {
return;
@ -449,7 +450,7 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const {
}
}
void KernelActor::EraseInput(OpContext<DeviceTensor> *context) {
void KernelActor::EraseInput(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
if (input_datas_num_ != 0) {
auto ret = input_op_datas_.erase(context->sequential_num_);

View File

@ -65,45 +65,45 @@ class KernelActor : public DebugAwareActor {
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
// The kernel actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
// The kernel actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// The kernel actor run when receive the input control and input tensors, used in step mode.
void RunOpControlWithInputTensor(AID *const input_control, OpContext<DeviceTensor> *const context,
const std::vector<TensorPtr> *input_tensors);
// The memory related operation interface.
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
// The callback after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
// The debug related operation interface.
void SendDebugReq(OpContext<DeviceTensor> *context) override;
void SendDebugReq(OpContext<DeviceTensor> *const context) override;
// The callback after debug finished.
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
private:
friend class GraphScheduler;
// Check whether satisfy the condition for launch.
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
// Fetch the device tensor for launch.
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
void FetchOutputDeviceTensor();
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);
// The processing before kernel launch: update the info of kernel launch.
void PreLaunchKernel(OpContext<DeviceTensor> *context);
void PreLaunchKernel(OpContext<DeviceTensor> *const context);
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
void PostLaunchKernel(OpContext<DeviceTensor> *context);
void PostLaunchKernel(OpContext<DeviceTensor> *const context);
// Send output data and output controls when finish kernel launch.
void SendOutput(OpContext<DeviceTensor> *context) const;
void SendOutput(OpContext<DeviceTensor> *const context) const;
// Erase input data and input controls when finish kernel launch.
void EraseInput(OpContext<DeviceTensor> *context);
void EraseInput(OpContext<DeviceTensor> *const context);
// The info of kernel.
CNodePtr kernel_;

View File

@ -27,8 +27,8 @@
namespace mindspore {
namespace runtime {
namespace {
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *addr_list,
std::vector<size_t> *size_list, size_t *total_size, bool is_input) {
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *const addr_list,
std::vector<size_t> *const size_list, size_t *const total_size, bool is_input) {
MS_EXCEPTION_IF_NULL(node);
const auto &kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
@ -82,7 +82,7 @@ void LoopCountActor::Init() {
}
}
void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
input_op_controls_[sequential_num].emplace_back(input_control);
@ -91,16 +91,16 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
}
}
void LoopCountActor::SendDebugReq(OpContext<DeviceTensor> *context) {
void LoopCountActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
Async(*debug_aid_, &DebugActor::DebugOnStepEnd, context, &GetAID());
}
void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
SendOutput(context);
}
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *context) {
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;
auto ret = input_op_controls_.erase(sequential_num);
@ -123,7 +123,7 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *context) {
SendOutput(context);
}
void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
// Send recorder info.
if (recorder_aid_ != nullptr) {
Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
@ -131,7 +131,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
SendMemoryAllocReq(context);
}
void LoopCountActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
void LoopCountActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
if (current_count_ == loop_count_) {
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
// LoopCountActor exits, because other processors which are not in actor also will allocate or free memory.
@ -145,7 +145,7 @@ void LoopCountActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
}
}
void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
// Send loop count to output actor.
Async(output_aid_, &OutputActor::CollectLoopCount, current_count_, context);
@ -166,7 +166,7 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
}
}
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
auto sequential_num = context->sequential_num_;

View File

@ -50,25 +50,25 @@ class LoopCountActor : public DebugAwareActor {
void Init() override;
// The loop count actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// The memory related operation interface.
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
// The callback after memory alloc finished.
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
// The debug related operation interface.
void SendDebugReq(OpContext<DeviceTensor> *context) override;
void SendDebugReq(OpContext<DeviceTensor> *const context) override;
// The callback after debug finished.
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
private:
friend class GraphScheduler;
void IncreaseLoopCount(OpContext<DeviceTensor> *context);
void SendOutput(OpContext<DeviceTensor> *context);
void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
void SendOutput(OpContext<DeviceTensor> *const context);
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context);
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context);
// The loop count is constant, the current count is increased after each step running finished.
size_t loop_count_;
size_t current_count_;

View File

@ -29,9 +29,9 @@ class MemoryAwareActor : public OpActor<DeviceTensor> {
public:
explicit MemoryAwareActor(std::string name) : OpActor(name) {}
virtual ~MemoryAwareActor() = default;
virtual void SendMemoryAllocReq(OpContext<DeviceTensor> *context) {}
virtual void SendMemoryFreeReq(OpContext<DeviceTensor> *context) {}
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {}
virtual void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {}
virtual void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {}
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {}
friend class GraphScheduler;
};

View File

@ -22,8 +22,9 @@
namespace mindspore {
namespace runtime {
void MemoryManagerActor::AllocateMemory(std::vector<DeviceTensor *> *alloc_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context, const AID from_aid) {
void MemoryManagerActor::AllocateMemory(const std::vector<DeviceTensor *> *alloc_list,
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context,
const AID from_aid) {
MS_EXCEPTION_IF_NULL(alloc_list);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
@ -46,11 +47,11 @@ void MemoryManagerActor::AllocateMemory(std::vector<DeviceTensor *> *alloc_list,
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
}
void MemoryManagerActor::AllocateContinuousMemory(std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
std::vector<std::vector<size_t>> *size_list_list,
std::vector<size_t> *total_size_list,
std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *op_context, const AID from_aid) {
void MemoryManagerActor::AllocateContinuousMemory(const std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
const std::vector<std::vector<size_t>> *size_list_list,
const std::vector<size_t> *total_size_list,
const std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *const op_context, const AID from_aid) {
MS_EXCEPTION_IF_NULL(alloc_list_list);
MS_EXCEPTION_IF_NULL(size_list_list);
MS_EXCEPTION_IF_NULL(total_size_list);
@ -81,9 +82,9 @@ void MemoryManagerActor::AllocateContinuousMemory(std::vector<std::vector<Device
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
}
void MemoryManagerActor::AllocateBatchMemory(std::vector<DeviceTensor *> *alloc_list,
std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *op_context, const AID from_aid) {
void MemoryManagerActor::AllocateBatchMemory(const std::vector<DeviceTensor *> *alloc_list,
const std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *const op_context, const AID from_aid) {
MS_EXCEPTION_IF_NULL(alloc_list);
MS_EXCEPTION_IF_NULL(device_contexts);
MS_EXCEPTION_IF_NULL(op_context);
@ -114,7 +115,7 @@ void MemoryManagerActor::AllocateBatchMemory(std::vector<DeviceTensor *> *alloc_
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
}
void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
void MemoryManagerActor::FreeMemory(const std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *) {
MS_EXCEPTION_IF_NULL(free_list);
MS_EXCEPTION_IF_NULL(device_context);
@ -135,9 +136,9 @@ void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> *free_list, cons
}
}
void MemoryManagerActor::FreeBatchMemory(std::vector<DeviceTensor *> *free_list,
std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *op_context) {
void MemoryManagerActor::FreeBatchMemory(const std::vector<DeviceTensor *> *free_list,
const std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *const op_context) {
MS_EXCEPTION_IF_NULL(free_list);
MS_EXCEPTION_IF_NULL(device_contexts);
MS_EXCEPTION_IF_NULL(op_context);
@ -166,7 +167,7 @@ void MemoryManagerActor::FreeBatchMemory(std::vector<DeviceTensor *> *free_list,
}
}
void MemoryManagerActor::Wait(OpContext<DeviceTensor> *op_context, const AID from_aid) {
void MemoryManagerActor::Wait(OpContext<DeviceTensor> *const op_context, const AID from_aid) {
// Call back to the from actor to process.
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
}

View File

@ -36,27 +36,30 @@ class MemoryManagerActor : public ActorBase {
~MemoryManagerActor() override = default;
// The process entry of memory alloc.
void AllocateMemory(std::vector<DeviceTensor *> *alloc_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context, const AID from_aid);
void AllocateMemory(const std::vector<DeviceTensor *> *alloc_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *const op_context, const AID from_aid);
// The process entry of continuous memory alloc, the size of alloc_list_list, size_list_list, total_size_list and
// device_contexts must be equal.
void AllocateContinuousMemory(std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
std::vector<std::vector<size_t>> *size_list_list, std::vector<size_t> *total_size_list,
std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *op_context, const AID from_aid);
void AllocateContinuousMemory(const std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
const std::vector<std::vector<size_t>> *size_list_list,
const std::vector<size_t> *total_size_list,
const std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *const op_context, const AID from_aid);
// device_contexts is from different device, the size of device_contexts must be equal to the alloc_list.
void AllocateBatchMemory(std::vector<DeviceTensor *> *alloc_list, std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *op_context, const AID from_aid);
void AllocateBatchMemory(const std::vector<DeviceTensor *> *alloc_list,
const std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *const op_context, const AID from_aid);
// The process entry of memory free.
void FreeMemory(std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context);
void FreeMemory(const std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *const op_context);
// device_contexts is from different device, the size of device_contexts must be equal to the free_list.
void FreeBatchMemory(std::vector<DeviceTensor *> *free_list, std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *op_context);
void FreeBatchMemory(const std::vector<DeviceTensor *> *free_list,
const std::vector<const DeviceContext *> *device_contexts,
OpContext<DeviceTensor> *const op_context);
// Wait the MemoryManagerActor to finish running all current messages.
void Wait(OpContext<DeviceTensor> *op_context, const AID from_aid);
void Wait(OpContext<DeviceTensor> *const op_context, const AID from_aid);
};
} // namespace runtime
} // namespace mindspore

View File

@ -49,7 +49,7 @@ void OutputActor::Init() {
}
}
void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
current_count_ = loop_count;
@ -106,7 +106,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
}
void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *context) {
OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(output_node);
MS_EXCEPTION_IF_NULL(context);
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".

View File

@ -57,11 +57,11 @@ class OutputActor : public OpActor<DeviceTensor> {
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
// The output actor collects loop count when receive the input control of loop count actor.
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context);
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context);
// The output actor collects output result when receive the data of actor.
void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
OpContext<DeviceTensor> *context);
OpContext<DeviceTensor> *const context);
// The graph output need be set new device address every step or loop, to avoid that the device address
// context of tensor be rewritten in the next step or next loop.

View File

@ -24,7 +24,7 @@
namespace mindspore {
namespace runtime {
void RecorderActor::RecordInfo(const std::string op_name, const KernelLaunchInfo *launch_info_,
const DeviceContext *device_context, OpContext<DeviceTensor> *op_context) {
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context) {
MS_EXCEPTION_IF_NULL(launch_info_);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(op_context);
@ -62,7 +62,7 @@ void RecorderActor::RecordInfo(const std::string op_name, const KernelLaunchInfo
#endif
}
void RecorderActor::RecordOnStepEnd(OpContext<DeviceTensor> *op_context) {
void RecorderActor::RecordOnStepEnd(OpContext<DeviceTensor> *const op_context) {
MS_EXCEPTION_IF_NULL(op_context);
// todo clear
// Record iter_start, fp_start and iter_end op name and timestamp at the step end. (GPU)

View File

@ -37,11 +37,11 @@ class RecorderActor : public ActorBase {
// The memory recorder of each node.
void RecordInfo(const std::string op_name, const KernelLaunchInfo *launch_info_, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context);
OpContext<DeviceTensor> *const op_context);
// Clear memory recorder at the step end.
// Record fp_start and iter_end op name and timestamp at the step end. (GPU)
void RecordOnStepEnd(OpContext<DeviceTensor> *op_context);
void RecordOnStepEnd(OpContext<DeviceTensor> *const op_context);
};
} // namespace runtime
} // namespace mindspore

View File

@ -38,7 +38,7 @@ void SwitchActor::Init() {
}
}
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context);
const auto &sequential_num = context->sequential_num_;
auto &input_datas = input_data_[sequential_num];
@ -477,7 +477,7 @@ void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
}
}
void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
}

View File

@ -71,7 +71,7 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
void Init() override;
// The switch actor run when receive the input data.
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
// The switch actor run when receive the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
// The switch actor run when receive the input branch id.
@ -108,7 +108,7 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
void SendOutput(OpContext<DeviceTensor> *context);
// Erase input data and input controls when finish switch launch.
void EraseInput(OpContext<DeviceTensor> *context);
void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context);
// Collect all the backend inputs of switch actor.
void FetchInputNode(const ControlNodeParserPtr &parser);

View File

@ -54,7 +54,7 @@ bool IsNeedInsertCopyActor(const DeviceContext *from_devcie_context, const Devic
}
}
void UpdateRefCount(DeviceTensor *device_tensor, bool is_max_ref_count = false) {
void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count = false) {
MS_EXCEPTION_IF_NULL(device_tensor);
if (is_max_ref_count) {
device_tensor->set_original_ref_count(SIZE_MAX);
@ -1433,14 +1433,17 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
if (IsDeviceQueueDSActor(front_output_node)) {
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
MS_EXCEPTION_IF_NULL(from_actor);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsKernelActor(front_output_node)) {
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
MS_EXCEPTION_IF_NULL(from_actor);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(front_output_node, graph, nullptr, host_parameters)) {
auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor_pair.first);
MS_EXCEPTION_IF_NULL(from_actor);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_nodes_[actor_pair.second], 0);
LinkDataArrowForHostDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
@ -1448,7 +1451,8 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
}
}
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *const from_actor,
KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
@ -1473,7 +1477,8 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *f
}
}
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *const from_actor,
KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
@ -1501,7 +1506,7 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
}
}
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(to_actor);
@ -1541,7 +1546,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
}
}
void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from_actor, KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
@ -1571,14 +1576,17 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
auto op_arrow_to_copy = std::make_shared<DataArrow>(from_output_index, copy_actor->GetAID(), 0);
if (IsDeviceQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
MS_EXCEPTION_IF_NULL(real_from_actor);
from_devcie_context = real_from_actor->device_context_;
real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
} else if (IsKernelActor(from_kernel)) {
auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
MS_EXCEPTION_IF_NULL(real_from_actor);
from_devcie_context = real_from_actor->device_context_;
real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
} else if (IsHostQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
MS_EXCEPTION_IF_NULL(real_from_actor);
auto position = real_from_actor->FetchDataNodePosition(from_kernel);
from_devcie_context = real_from_actor->device_contexts_[position];
op_arrow_to_copy->from_output_index_ = position;

View File

@ -203,15 +203,15 @@ class GraphScheduler {
const std::vector<AnfNodePtr> &host_parameters, const KernelGraphPtr &graph,
KernelActor *to_actor, KernelWithIndex to_kernel_with_input_idx);
// Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
void LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
void LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from_actor, KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx);
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *const from_actor, KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_to_kernel_with_input_idx);
void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *const from_actor, KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *const to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);