!21091 unified runtime codedex fixed
Merge pull request !21091 from limingqi107/bug_fix4
This commit is contained in:
commit
0816f95653
|
@ -277,6 +277,7 @@ void DynamicMemPoolBestFit::ReleaseDeviceRes() {
|
|||
if (!FreeDeviceMem(device_addr)) {
|
||||
MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
|
||||
}
|
||||
device_addr = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -319,10 +319,10 @@ class KernelGraph : public FuncGraph {
|
|||
void InsertToSendRecvPair(const CNodePtr &allreduce, const std::pair<CNodePtr, CNodePtr> &send_recv_pair) {
|
||||
allreduce_to_send_recv_pairs_[allreduce] = send_recv_pair;
|
||||
}
|
||||
std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() {
|
||||
const std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_from_send_recv_pairs() const {
|
||||
return allreduce_from_send_recv_pairs_;
|
||||
}
|
||||
std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() {
|
||||
const std::unordered_map<CNodePtr, std::pair<CNodePtr, CNodePtr>> &allreduce_to_send_recv_pairs() const {
|
||||
return allreduce_to_send_recv_pairs_;
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ class DumpJsonParser {
|
|||
|
||||
void ClearGraph() { graphs_.clear(); }
|
||||
void SaveGraph(session::KernelGraph *graph) { (void)graphs_.emplace_back(graph); }
|
||||
std::vector<session::KernelGraph *> &graphs() { return graphs_; }
|
||||
const std::vector<session::KernelGraph *> &graphs() const { return graphs_; }
|
||||
|
||||
private:
|
||||
DumpJsonParser() = default;
|
||||
|
|
|
@ -158,7 +158,7 @@ void AscendDeviceAddress::SyncStream() const {
|
|||
MS_LOG(DEBUG) << "Finish!";
|
||||
}
|
||||
|
||||
bool AscendDeviceAddress::SyncDeviceToHost(size_t size, void *host_ptr) const {
|
||||
bool AscendDeviceAddress::SyncDeviceToHost(size_t size, void *const host_ptr) const {
|
||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
SyncStream();
|
||||
SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);
|
||||
|
|
|
@ -42,7 +42,7 @@ class AscendDeviceAddress : public DeviceAddress {
|
|||
const KernelWithIndex &node_index)
|
||||
: DeviceAddress(ptr, size, format, type_id, node_index) {}
|
||||
~AscendDeviceAddress() override;
|
||||
bool SyncDeviceToHost(size_t size, void *host_ptr) const override;
|
||||
bool SyncDeviceToHost(size_t size, void *const host_ptr) const override;
|
||||
bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
|
||||
bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
|
||||
bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
|
||||
|
|
|
@ -106,7 +106,7 @@ void IntToLong(void *dst, const void *src, size_t elem_num) {
|
|||
}
|
||||
}
|
||||
|
||||
void ConvertSameType(void *dst, const void *src, size_t size, TypeId type) {
|
||||
void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type) {
|
||||
if (dst == nullptr || src == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ void ShortToInt(void *dst, const void *src, size_t elem_num);
|
|||
void IntToShort(void *dst, const void *src, size_t elem_num);
|
||||
void LongToInt(void *dst, const void *src, size_t elem_num);
|
||||
void IntToLong(void *dst, const void *src, size_t elem_num);
|
||||
void ConvertSameType(void *dst, const void *src, size_t size, TypeId type);
|
||||
void ConvertSameType(void *const dst, const void *src, size_t size, TypeId type);
|
||||
|
||||
template <typename T>
|
||||
void ConvertSameType(T *dst, const T *src, size_t elem_num) {
|
||||
|
|
|
@ -76,6 +76,8 @@ bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId
|
|||
} else if (ret_code != EOK) {
|
||||
MS_LOG(ERROR) << "Failed to copy tensor!";
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else if (type == kNumberTypeFloat16 && type_id_ == kNumberTypeFloat32) {
|
||||
FloatToHalf(host_ptr, ptr_, size >> 1);
|
||||
|
|
|
@ -70,8 +70,8 @@ class KernelInfo : public KernelInfoDevice {
|
|||
uint32_t graph_id() const { return graph_id_; }
|
||||
bool operator==(const KernelInfo &other) const;
|
||||
bool is_feature_map() const { return is_feature_map_; }
|
||||
std::vector<std::shared_ptr<DeviceAddress>> &output_address_list() { return output_address_list_; }
|
||||
std::vector<std::shared_ptr<DeviceAddress>> &workspace_address_list() { return workspace_address_list_; }
|
||||
const std::vector<std::shared_ptr<DeviceAddress>> &output_address_list() const { return output_address_list_; }
|
||||
const std::vector<std::shared_ptr<DeviceAddress>> &workspace_address_list() const { return workspace_address_list_; }
|
||||
|
||||
private:
|
||||
bool is_feature_map_;
|
||||
|
|
|
@ -114,7 +114,7 @@ bool IsGatherActor(const AnfNodePtr &front_node,
|
|||
return false;
|
||||
}
|
||||
|
||||
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
|
||||
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
|
||||
MS_EXCEPTION_IF_NULL(dst_device_tensor);
|
||||
MS_EXCEPTION_IF_NULL(src_device_tensor);
|
||||
if (src_device_tensor->GetSize() != dst_device_tensor->GetSize()) {
|
||||
|
|
|
@ -89,7 +89,7 @@ bool IsGatherActor(const AnfNodePtr &front_node,
|
|||
const std::unordered_map<std::string, OpActor<DeviceTensor> *> &actor_name_to_actor);
|
||||
|
||||
// Copy data from src_device_tensor to dst_device_tensor.
|
||||
bool Copy(DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor);
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ void CopyActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
void CopyActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
|
@ -49,7 +49,7 @@ void CopyActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTens
|
|||
}
|
||||
}
|
||||
|
||||
void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
|
@ -60,17 +60,17 @@ void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *contex
|
|||
}
|
||||
}
|
||||
|
||||
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &output_device_tensor_, output_device_context_,
|
||||
context, GetAID());
|
||||
}
|
||||
|
||||
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensor_, input_device_context_, context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &output_device_tensor_, output_device_context_, context);
|
||||
}
|
||||
|
||||
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(output_device_tensor_[0]);
|
||||
MS_EXCEPTION_IF_NULL(input_device_tensor_[0]);
|
||||
|
@ -96,7 +96,7 @@ void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
|
||||
bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
|
@ -120,7 +120,7 @@ bool CopyActor::CheckCopyCondition(OpContext<DeviceTensor> *context) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(input_device_context_);
|
||||
|
||||
|
@ -156,7 +156,7 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
||||
void CopyActor::SendOutput(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// No output.
|
||||
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0)) {
|
||||
|
@ -179,7 +179,7 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
}
|
||||
}
|
||||
|
||||
void CopyActor::EraseInput(OpContext<DeviceTensor> *context) {
|
||||
void CopyActor::EraseInput(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
|
|
|
@ -41,34 +41,36 @@ class CopyActor : public MemoryAwareActor {
|
|||
memory_manager_aid_(memory_manager_aid),
|
||||
input_datas_num_(0),
|
||||
input_controls_num_(0),
|
||||
input_device_context_(nullptr),
|
||||
output_device_context_(nullptr),
|
||||
output_(nullptr) {}
|
||||
~CopyActor() override = default;
|
||||
|
||||
void Init() override;
|
||||
|
||||
// The copy actor run when receive the input data.
|
||||
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
|
||||
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
|
||||
// The copy actor run when receive the input control.
|
||||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
|
||||
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
// The memory related operation interface.
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
|
||||
// The copy processing after memory alloc finished.
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Check whether satisfy the condition for copy.
|
||||
bool CheckCopyCondition(OpContext<DeviceTensor> *context) const;
|
||||
bool CheckCopyCondition(OpContext<DeviceTensor> *const context) const;
|
||||
// Fetch the device tensor for copy.
|
||||
void FetchDeviceTensor(OpContext<DeviceTensor> *context);
|
||||
void FetchDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// Send output data and output controls when finish copy.
|
||||
void SendOutput(OpContext<DeviceTensor> *context) const;
|
||||
void SendOutput(OpContext<DeviceTensor> *const context) const;
|
||||
// Erase input data and input controls when finish copy.
|
||||
void EraseInput(OpContext<DeviceTensor> *context);
|
||||
void EraseInput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The id of memory manager actor. Send message to it for alloc and free memory during the copy.
|
||||
const AID memory_manager_aid_;
|
||||
|
|
|
@ -35,7 +35,7 @@ void DataSourceActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) {
|
||||
void DataSourceActor::FetchData(OpContext<DeviceTensor> *const context) {
|
||||
MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") fetches data.";
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// Pop the data of last time.
|
||||
|
@ -53,7 +53,7 @@ void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) {
|
|||
SendMemoryAllocReq(context);
|
||||
}
|
||||
|
||||
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) {
|
||||
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// No output.
|
||||
if ((output_data_arrows_.size() == 0) && (output_control_arrows_.size() == 0) &&
|
||||
|
@ -124,17 +124,17 @@ void DeviceQueueDataSourceActor::FillDataBuffer() {
|
|||
buffers_.push(device_tensors);
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
auto &device_tensors = buffers_.back();
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_context_, context, GetAID());
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
auto &device_tensors = buffers_.front();
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_context_, context);
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
if (buffers_.size() == 0) {
|
||||
|
@ -177,16 +177,16 @@ void DeviceQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *co
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(*debug_aid_, &DebugActor::Debug, data_kernel_, &launch_info_, device_context_, context, &GetAID());
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
|
||||
SendMemoryFreeReq(context);
|
||||
SendOutput(context);
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
|
||||
for (const auto &result_arrow : output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
Async(result_arrow->to_op_id_, &OutputActor::CollectOutput, data_kernel_, result_arrow->from_output_index_,
|
||||
|
@ -194,7 +194,7 @@ void DeviceQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *context) {
|
||||
void DeviceQueueDataSourceActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) {
|
||||
if (recorder_aid_ != nullptr) {
|
||||
Async(*recorder_aid_, &RecorderActor::RecordInfo, data_kernel_->fullname_with_scope(), &launch_info_,
|
||||
device_context_, context);
|
||||
|
@ -213,7 +213,7 @@ void HostQueueDataSourceActor::FillDataBuffer() {
|
|||
buffers_.push(device_tensors);
|
||||
}
|
||||
|
||||
void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
||||
void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
auto &device_tensors = buffers_.back();
|
||||
if (IsSameDeviceType()) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &device_tensors, device_contexts_[0], context,
|
||||
|
@ -224,7 +224,7 @@ void HostQueueDataSourceActor::SendMemoryAllocReq(OpContext<DeviceTensor> *conte
|
|||
}
|
||||
}
|
||||
|
||||
void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
||||
void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
auto &device_tensors = buffers_.front();
|
||||
if (IsSameDeviceType()) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &device_tensors, device_contexts_[0], context);
|
||||
|
@ -233,7 +233,7 @@ void HostQueueDataSourceActor::SendMemoryFreeReq(OpContext<DeviceTensor> *contex
|
|||
}
|
||||
}
|
||||
|
||||
void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
||||
void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (buffers_.size() == 0) {
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
|
||||
|
@ -283,7 +283,7 @@ void HostQueueDataSourceActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *cont
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *context) {
|
||||
void HostQueueDataSourceActor::SendResult(OpContext<DeviceTensor> *const context) {
|
||||
for (const auto &result_arrow : output_result_arrows_) {
|
||||
MS_EXCEPTION_IF_NULL(result_arrow);
|
||||
if (IntToSize(result_arrow->from_output_index_) >= data_nodes_.size()) {
|
||||
|
|
|
@ -53,13 +53,13 @@ class DataSourceActor : public DebugAwareActor {
|
|||
void Init() override;
|
||||
|
||||
// The process entry of data processing.
|
||||
void FetchData(OpContext<DeviceTensor> *context);
|
||||
void FetchData(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The memory related operation interface.
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override{};
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override{};
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override{};
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override{};
|
||||
// Copy data from data source to the device tensor buffer of actor after memory alloc finished.
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override{};
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override{};
|
||||
|
||||
protected:
|
||||
friend class GraphScheduler;
|
||||
|
@ -68,13 +68,13 @@ class DataSourceActor : public DebugAwareActor {
|
|||
virtual void FillDataBuffer() = 0;
|
||||
|
||||
// Send output result of graph output to output actor.
|
||||
virtual void SendResult(OpContext<DeviceTensor> *context) = 0;
|
||||
virtual void SendResult(OpContext<DeviceTensor> *const context) = 0;
|
||||
|
||||
// Send recorder info to recorder actor, only the device queue data source actor need.
|
||||
virtual void SendRecorderInfo(OpContext<DeviceTensor> *context) {}
|
||||
virtual void SendRecorderInfo(OpContext<DeviceTensor> *const context) {}
|
||||
|
||||
// Send output to downstream actors to trigger computing after fetching data finished.
|
||||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
void SendOutput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The output result arrows of graph output.
|
||||
std::vector<DataArrowPtr> output_result_arrows_;
|
||||
|
@ -105,17 +105,17 @@ class DeviceQueueDataSourceActor : public DataSourceActor {
|
|||
|
||||
void Init() override;
|
||||
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
void SendDebugReq(OpContext<DeviceTensor> *context) override;
|
||||
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
|
||||
void SendDebugReq(OpContext<DeviceTensor> *const context) override;
|
||||
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
protected:
|
||||
void FillDataBuffer() override;
|
||||
void SendResult(OpContext<DeviceTensor> *context) override;
|
||||
void SendRecorderInfo(OpContext<DeviceTensor> *context) override;
|
||||
void SendResult(OpContext<DeviceTensor> *const context) override;
|
||||
void SendRecorderInfo(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
@ -138,15 +138,15 @@ class HostQueueDataSourceActor : public DataSourceActor {
|
|||
: DataSourceActor(name, buffer_capacity, memory_manager_aid, debug_aid, recorder_aid), host_queue_(host_queue) {}
|
||||
~HostQueueDataSourceActor() override = default;
|
||||
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
size_t FetchDataNodePosition(const AnfNodePtr &data_node) const;
|
||||
|
||||
protected:
|
||||
void FillDataBuffer() override;
|
||||
void SendResult(OpContext<DeviceTensor> *context) override;
|
||||
void SendResult(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
|
|
@ -31,7 +31,8 @@ namespace mindspore {
|
|||
namespace runtime {
|
||||
|
||||
void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_,
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *op_context, const AID *from_aid) {
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context,
|
||||
const AID *from_aid) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
@ -69,7 +70,7 @@ void DebugActor::Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_in
|
|||
Async(*from_aid, &DebugAwareActor::OnDebugFinish, op_context);
|
||||
}
|
||||
|
||||
void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *from_aid) {
|
||||
void DebugActor::DebugOnStepEnd(OpContext<DeviceTensor> *const op_context, const AID *from_aid) {
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
MS_EXCEPTION_IF_NULL(from_aid);
|
||||
|
||||
|
|
|
@ -35,10 +35,10 @@ class DebugActor : public ActorBase {
|
|||
|
||||
// The debug of each node.
|
||||
void Debug(const AnfNodePtr &node, const KernelLaunchInfo *launch_info_, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context, const AID *from_aid);
|
||||
OpContext<DeviceTensor> *const op_context, const AID *from_aid);
|
||||
|
||||
// The debug on step end.
|
||||
void DebugOnStepEnd(OpContext<DeviceTensor> *op_context, const AID *from_aid);
|
||||
void DebugOnStepEnd(OpContext<DeviceTensor> *const op_context, const AID *from_aid);
|
||||
|
||||
private:
|
||||
// class members
|
||||
|
|
|
@ -27,8 +27,8 @@ class DebugAwareActor : public MemoryAwareActor {
|
|||
public:
|
||||
explicit DebugAwareActor(const std::string &name) : MemoryAwareActor(name) {}
|
||||
virtual ~DebugAwareActor() = default;
|
||||
virtual void SendDebugReq(OpContext<DeviceTensor> *context) {}
|
||||
virtual void OnDebugFinish(OpContext<DeviceTensor> *context) {}
|
||||
virtual void SendDebugReq(OpContext<DeviceTensor> *const context) {}
|
||||
virtual void OnDebugFinish(OpContext<DeviceTensor> *const context) {}
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -74,7 +74,7 @@ void KernelActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
input_op_datas_[sequential_num].emplace_back(input_data);
|
||||
|
@ -100,7 +100,7 @@ void KernelActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTe
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
|
@ -178,7 +178,7 @@ void FreeMemory(const std::vector<DeviceTensor *> &free_list, const DeviceContex
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
running_dependent_msg_num_ = 1;
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_, device_context_, context,
|
||||
|
@ -188,7 +188,7 @@ void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
if (strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &memory_free_list_, device_context_, context);
|
||||
} else {
|
||||
|
@ -196,7 +196,7 @@ void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(kernel_);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
|
@ -224,7 +224,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
|||
PostLaunchKernel(context);
|
||||
}
|
||||
|
||||
void KernelActor::SendDebugReq(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
|
||||
running_dependent_msg_num_ = 1;
|
||||
Async(*debug_aid_, &DebugActor::Debug, kernel_, &launch_info_, device_context_, context, &GetAID());
|
||||
}
|
||||
|
@ -234,7 +234,7 @@ void KernelActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
|
|||
PostLaunchKernel(context);
|
||||
}
|
||||
|
||||
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
|
||||
bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
|
@ -276,7 +276,8 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
||||
OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(input_data);
|
||||
if ((input_data->data_ == nullptr) || (input_data->data_->DeviceType() == device_context_->GetDeviceAddressType())) {
|
||||
return;
|
||||
|
@ -312,7 +313,7 @@ void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
|
|||
memory_free_list_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get();
|
||||
}
|
||||
|
||||
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
MS_EXCEPTION_IF_NULL(device_context_);
|
||||
|
||||
|
@ -392,7 +393,7 @@ void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
|
||||
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
|
||||
|
||||
// The input is invalid and needs to be erased when finish kernel launch.
|
||||
|
@ -408,7 +409,7 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *context) {
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
||||
void KernelActor::SendOutput(OpContext<DeviceTensor> *const context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (strategy_ == GraphExecutionStrategy::kStep) {
|
||||
return;
|
||||
|
@ -449,7 +450,7 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
}
|
||||
}
|
||||
|
||||
void KernelActor::EraseInput(OpContext<DeviceTensor> *context) {
|
||||
void KernelActor::EraseInput(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
if (input_datas_num_ != 0) {
|
||||
auto ret = input_op_datas_.erase(context->sequential_num_);
|
||||
|
|
|
@ -65,45 +65,45 @@ class KernelActor : public DebugAwareActor {
|
|||
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
|
||||
|
||||
// The kernel actor run when receive the input data.
|
||||
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) override;
|
||||
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
|
||||
// The kernel actor run when receive the input control.
|
||||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
|
||||
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
|
||||
// The kernel actor run when receive the input control and input tensors, used in step mode.
|
||||
void RunOpControlWithInputTensor(AID *const input_control, OpContext<DeviceTensor> *const context,
|
||||
const std::vector<TensorPtr> *input_tensors);
|
||||
|
||||
// The memory related operation interface.
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
|
||||
// The callback after memory alloc finished.
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
// The debug related operation interface.
|
||||
void SendDebugReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendDebugReq(OpContext<DeviceTensor> *const context) override;
|
||||
// The callback after debug finished.
|
||||
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
|
||||
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Check whether satisfy the condition for launch.
|
||||
bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const;
|
||||
bool CheckLaunchCondition(OpContext<DeviceTensor> *const context) const;
|
||||
// Fetch the device tensor for launch.
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *context);
|
||||
void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
|
||||
void FetchOutputDeviceTensor();
|
||||
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
|
||||
void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
|
||||
// In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly.
|
||||
void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors);
|
||||
|
||||
// The processing before kernel launch: update the info of kernel launch.
|
||||
void PreLaunchKernel(OpContext<DeviceTensor> *context);
|
||||
void PreLaunchKernel(OpContext<DeviceTensor> *const context);
|
||||
// The processing after kernel launch: 1.erase input, 2.free memory, 3.send output.
|
||||
void PostLaunchKernel(OpContext<DeviceTensor> *context);
|
||||
void PostLaunchKernel(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// Send output data and output controls when finish kernel launch.
|
||||
void SendOutput(OpContext<DeviceTensor> *context) const;
|
||||
void SendOutput(OpContext<DeviceTensor> *const context) const;
|
||||
// Erase input data and input controls when finish kernel launch.
|
||||
void EraseInput(OpContext<DeviceTensor> *context);
|
||||
void EraseInput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The info of kernel.
|
||||
CNodePtr kernel_;
|
||||
|
|
|
@ -27,8 +27,8 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
namespace {
|
||||
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *addr_list,
|
||||
std::vector<size_t> *size_list, size_t *total_size, bool is_input) {
|
||||
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *const addr_list,
|
||||
std::vector<size_t> *const size_list, size_t *const total_size, bool is_input) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
const auto &kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
|
@ -82,7 +82,7 @@ void LoopCountActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
input_op_controls_[sequential_num].emplace_back(input_control);
|
||||
|
@ -91,16 +91,16 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c
|
|||
}
|
||||
}
|
||||
|
||||
void LoopCountActor::SendDebugReq(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(*debug_aid_, &DebugActor::DebugOnStepEnd, context, &GetAID());
|
||||
}
|
||||
|
||||
void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
SendOutput(context);
|
||||
}
|
||||
|
||||
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
auto ret = input_op_controls_.erase(sequential_num);
|
||||
|
@ -123,7 +123,7 @@ void LoopCountActor::IncreaseLoopCount(OpContext<DeviceTensor> *context) {
|
|||
SendOutput(context);
|
||||
}
|
||||
|
||||
void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
||||
// Send recorder info.
|
||||
if (recorder_aid_ != nullptr) {
|
||||
Async(*recorder_aid_, &RecorderActor::RecordOnStepEnd, context);
|
||||
|
@ -131,7 +131,7 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *context) {
|
|||
SendMemoryAllocReq(context);
|
||||
}
|
||||
|
||||
void LoopCountActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
|
||||
if (current_count_ == loop_count_) {
|
||||
// Need wait MemoryManagerActor running finished to avoid the illegal memory timing problem before
|
||||
// LoopCountActor exits, because other processors which are not in actor also will allocate or free memory.
|
||||
|
@ -145,7 +145,7 @@ void LoopCountActor::SendMemoryAllocReq(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
||||
void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// Send loop count to output actor.
|
||||
Async(output_aid_, &OutputActor::CollectLoopCount, current_count_, context);
|
||||
|
@ -166,7 +166,7 @@ void LoopCountActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context) {
|
||||
bool LoopCountActor::CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto sequential_num = context->sequential_num_;
|
||||
|
||||
|
|
|
@ -50,25 +50,25 @@ class LoopCountActor : public DebugAwareActor {
|
|||
void Init() override;
|
||||
|
||||
// The loop count actor run when receive the input control.
|
||||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
|
||||
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
// The memory related operation interface.
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
|
||||
// The callback after memory alloc finished.
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override;
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
// The debug related operation interface.
|
||||
void SendDebugReq(OpContext<DeviceTensor> *context) override;
|
||||
void SendDebugReq(OpContext<DeviceTensor> *const context) override;
|
||||
// The callback after debug finished.
|
||||
void OnDebugFinish(OpContext<DeviceTensor> *context) override;
|
||||
void OnDebugFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
void IncreaseLoopCount(OpContext<DeviceTensor> *context);
|
||||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
|
||||
void SendOutput(OpContext<DeviceTensor> *const context);
|
||||
|
||||
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *context);
|
||||
bool CheckLoopCountIncreaseCondition(OpContext<DeviceTensor> *const context);
|
||||
// The loop count is constant, the current count is increased after each step running finished.
|
||||
size_t loop_count_;
|
||||
size_t current_count_;
|
||||
|
|
|
@ -29,9 +29,9 @@ class MemoryAwareActor : public OpActor<DeviceTensor> {
|
|||
public:
|
||||
explicit MemoryAwareActor(std::string name) : OpActor(name) {}
|
||||
virtual ~MemoryAwareActor() = default;
|
||||
virtual void SendMemoryAllocReq(OpContext<DeviceTensor> *context) {}
|
||||
virtual void SendMemoryFreeReq(OpContext<DeviceTensor> *context) {}
|
||||
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {}
|
||||
virtual void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {}
|
||||
virtual void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {}
|
||||
virtual void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {}
|
||||
|
||||
friend class GraphScheduler;
|
||||
};
|
||||
|
|
|
@ -22,8 +22,9 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void MemoryManagerActor::AllocateMemory(std::vector<DeviceTensor *> *alloc_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context, const AID from_aid) {
|
||||
void MemoryManagerActor::AllocateMemory(const std::vector<DeviceTensor *> *alloc_list,
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context,
|
||||
const AID from_aid) {
|
||||
MS_EXCEPTION_IF_NULL(alloc_list);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
@ -46,11 +47,11 @@ void MemoryManagerActor::AllocateMemory(std::vector<DeviceTensor *> *alloc_list,
|
|||
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
|
||||
}
|
||||
|
||||
void MemoryManagerActor::AllocateContinuousMemory(std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
|
||||
std::vector<std::vector<size_t>> *size_list_list,
|
||||
std::vector<size_t> *total_size_list,
|
||||
std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *op_context, const AID from_aid) {
|
||||
void MemoryManagerActor::AllocateContinuousMemory(const std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
|
||||
const std::vector<std::vector<size_t>> *size_list_list,
|
||||
const std::vector<size_t> *total_size_list,
|
||||
const std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *const op_context, const AID from_aid) {
|
||||
MS_EXCEPTION_IF_NULL(alloc_list_list);
|
||||
MS_EXCEPTION_IF_NULL(size_list_list);
|
||||
MS_EXCEPTION_IF_NULL(total_size_list);
|
||||
|
@ -81,9 +82,9 @@ void MemoryManagerActor::AllocateContinuousMemory(std::vector<std::vector<Device
|
|||
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
|
||||
}
|
||||
|
||||
void MemoryManagerActor::AllocateBatchMemory(std::vector<DeviceTensor *> *alloc_list,
|
||||
std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *op_context, const AID from_aid) {
|
||||
void MemoryManagerActor::AllocateBatchMemory(const std::vector<DeviceTensor *> *alloc_list,
|
||||
const std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *const op_context, const AID from_aid) {
|
||||
MS_EXCEPTION_IF_NULL(alloc_list);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
@ -114,7 +115,7 @@ void MemoryManagerActor::AllocateBatchMemory(std::vector<DeviceTensor *> *alloc_
|
|||
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
|
||||
}
|
||||
|
||||
void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
|
||||
void MemoryManagerActor::FreeMemory(const std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *) {
|
||||
MS_EXCEPTION_IF_NULL(free_list);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
|
@ -135,9 +136,9 @@ void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> *free_list, cons
|
|||
}
|
||||
}
|
||||
|
||||
void MemoryManagerActor::FreeBatchMemory(std::vector<DeviceTensor *> *free_list,
|
||||
std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *op_context) {
|
||||
void MemoryManagerActor::FreeBatchMemory(const std::vector<DeviceTensor *> *free_list,
|
||||
const std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *const op_context) {
|
||||
MS_EXCEPTION_IF_NULL(free_list);
|
||||
MS_EXCEPTION_IF_NULL(device_contexts);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
@ -166,7 +167,7 @@ void MemoryManagerActor::FreeBatchMemory(std::vector<DeviceTensor *> *free_list,
|
|||
}
|
||||
}
|
||||
|
||||
void MemoryManagerActor::Wait(OpContext<DeviceTensor> *op_context, const AID from_aid) {
|
||||
void MemoryManagerActor::Wait(OpContext<DeviceTensor> *const op_context, const AID from_aid) {
|
||||
// Call back to the from actor to process.
|
||||
Async(from_aid, &MemoryAwareActor::OnMemoryAllocFinish, op_context);
|
||||
}
|
||||
|
|
|
@ -36,27 +36,30 @@ class MemoryManagerActor : public ActorBase {
|
|||
~MemoryManagerActor() override = default;
|
||||
|
||||
// The process entry of memory alloc.
|
||||
void AllocateMemory(std::vector<DeviceTensor *> *alloc_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context, const AID from_aid);
|
||||
void AllocateMemory(const std::vector<DeviceTensor *> *alloc_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *const op_context, const AID from_aid);
|
||||
// The process entry of continuous memory alloc, the size of alloc_list_list, size_list_list, total_size_list and
|
||||
// device_contexts must be equal.
|
||||
void AllocateContinuousMemory(std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
|
||||
std::vector<std::vector<size_t>> *size_list_list, std::vector<size_t> *total_size_list,
|
||||
std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *op_context, const AID from_aid);
|
||||
void AllocateContinuousMemory(const std::vector<std::vector<DeviceTensorPtr>> *alloc_list_list,
|
||||
const std::vector<std::vector<size_t>> *size_list_list,
|
||||
const std::vector<size_t> *total_size_list,
|
||||
const std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *const op_context, const AID from_aid);
|
||||
// device_contexts is from different device, the size of device_contexts must be equal to the alloc_list.
|
||||
void AllocateBatchMemory(std::vector<DeviceTensor *> *alloc_list, std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *op_context, const AID from_aid);
|
||||
void AllocateBatchMemory(const std::vector<DeviceTensor *> *alloc_list,
|
||||
const std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *const op_context, const AID from_aid);
|
||||
|
||||
// The process entry of memory free.
|
||||
void FreeMemory(std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context);
|
||||
void FreeMemory(const std::vector<DeviceTensor *> *free_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *const op_context);
|
||||
// device_contexts is from different device, the size of device_contexts must be equal to the free_list.
|
||||
void FreeBatchMemory(std::vector<DeviceTensor *> *free_list, std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *op_context);
|
||||
void FreeBatchMemory(const std::vector<DeviceTensor *> *free_list,
|
||||
const std::vector<const DeviceContext *> *device_contexts,
|
||||
OpContext<DeviceTensor> *const op_context);
|
||||
|
||||
// Wait the MemoryManagerActor to finish running all current messages.
|
||||
void Wait(OpContext<DeviceTensor> *op_context, const AID from_aid);
|
||||
void Wait(OpContext<DeviceTensor> *const op_context, const AID from_aid);
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,7 +49,7 @@ void OutputActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context) {
|
||||
void OutputActor::CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
||||
current_count_ = loop_count;
|
||||
|
@ -106,7 +106,7 @@ void OutputActor::UpdateOutputDeviceAddress() {
|
|||
}
|
||||
|
||||
void OutputActor::CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
|
||||
OpContext<DeviceTensor> *context) {
|
||||
OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(output_node);
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// Collect the output result in the last loop which is represented by "loop_count_ - current_count_ == 1".
|
||||
|
|
|
@ -57,11 +57,11 @@ class OutputActor : public OpActor<DeviceTensor> {
|
|||
bool IsActive(int msg_num) override { return msg_num >= running_dependent_msg_num_ ? true : false; }
|
||||
|
||||
// The output actor collects loop count when receive the input control of loop count actor.
|
||||
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *context);
|
||||
void CollectLoopCount(size_t loop_count, OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The output actor collects output result when receive the data of actor.
|
||||
void CollectOutput(const AnfNodePtr &output_node, size_t output_index, size_t output_position,
|
||||
OpContext<DeviceTensor> *context);
|
||||
OpContext<DeviceTensor> *const context);
|
||||
|
||||
// The graph output need be set new device address every step or loop, to avoid that the device address
|
||||
// context of tensor be rewritten in the next step or next loop.
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void RecorderActor::RecordInfo(const std::string op_name, const KernelLaunchInfo *launch_info_,
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *op_context) {
|
||||
const DeviceContext *device_context, OpContext<DeviceTensor> *const op_context) {
|
||||
MS_EXCEPTION_IF_NULL(launch_info_);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
|
@ -62,7 +62,7 @@ void RecorderActor::RecordInfo(const std::string op_name, const KernelLaunchInfo
|
|||
#endif
|
||||
}
|
||||
|
||||
void RecorderActor::RecordOnStepEnd(OpContext<DeviceTensor> *op_context) {
|
||||
void RecorderActor::RecordOnStepEnd(OpContext<DeviceTensor> *const op_context) {
|
||||
MS_EXCEPTION_IF_NULL(op_context);
|
||||
// todo clear
|
||||
// Record iter_start, fp_start and iter_end op name and timestamp at the step end. (GPU)
|
||||
|
|
|
@ -37,11 +37,11 @@ class RecorderActor : public ActorBase {
|
|||
|
||||
// The memory recorder of each node.
|
||||
void RecordInfo(const std::string op_name, const KernelLaunchInfo *launch_info_, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context);
|
||||
OpContext<DeviceTensor> *const op_context);
|
||||
|
||||
// Clear memory recorder at the step end.
|
||||
// Record fp_start and iter_end op name and timestamp at the step end. (GPU)
|
||||
void RecordOnStepEnd(OpContext<DeviceTensor> *op_context);
|
||||
void RecordOnStepEnd(OpContext<DeviceTensor> *const op_context);
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,7 +38,7 @@ void SwitchActor::Init() {
|
|||
}
|
||||
}
|
||||
|
||||
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) {
|
||||
void SwitchActor::RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
const auto &sequential_num = context->sequential_num_;
|
||||
auto &input_datas = input_data_[sequential_num];
|
||||
|
@ -477,7 +477,7 @@ void SwitchActor::EraseInput(OpContext<DeviceTensor> *context) {
|
|||
}
|
||||
}
|
||||
|
||||
void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *context) {
|
||||
void SwitchActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &input_device_tensors_, device_context_, context);
|
||||
}
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
void Init() override;
|
||||
|
||||
// The switch actor run when receive the input data.
|
||||
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context);
|
||||
void RunOpData(OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *const context);
|
||||
// The switch actor run when receive the input control.
|
||||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context);
|
||||
// The switch actor run when receive the input branch id.
|
||||
|
@ -108,7 +108,7 @@ class SwitchActor : public SwitchActorBase<DeviceTensor> {
|
|||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
// Erase input data and input controls when finish switch launch.
|
||||
void EraseInput(OpContext<DeviceTensor> *context);
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *context);
|
||||
void SendMemoryFreeReq(OpContext<DeviceTensor> *const context);
|
||||
|
||||
// Collect all the backend inputs of switch actor.
|
||||
void FetchInputNode(const ControlNodeParserPtr &parser);
|
||||
|
|
|
@ -54,7 +54,7 @@ bool IsNeedInsertCopyActor(const DeviceContext *from_devcie_context, const Devic
|
|||
}
|
||||
}
|
||||
|
||||
void UpdateRefCount(DeviceTensor *device_tensor, bool is_max_ref_count = false) {
|
||||
void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count = false) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
if (is_max_ref_count) {
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
|
@ -1433,14 +1433,17 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
|
||||
if (IsDeviceQueueDSActor(front_output_node)) {
|
||||
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
|
||||
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsKernelActor(front_output_node)) {
|
||||
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
|
||||
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsHostQueueDSActor(front_output_node, graph, nullptr, host_parameters)) {
|
||||
auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor_pair.first);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_nodes_[actor_pair.second], 0);
|
||||
LinkDataArrowForHostDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
|
@ -1448,7 +1451,8 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *const from_actor,
|
||||
KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
@ -1473,7 +1477,8 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *f
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *const from_actor,
|
||||
KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
@ -1501,7 +1506,7 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
|
||||
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
@ -1541,7 +1546,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
|
||||
void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from_actor, KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
@ -1571,14 +1576,17 @@ void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor
|
|||
auto op_arrow_to_copy = std::make_shared<DataArrow>(from_output_index, copy_actor->GetAID(), 0);
|
||||
if (IsDeviceQueueDSActor(from_kernel)) {
|
||||
auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_from_actor);
|
||||
from_devcie_context = real_from_actor->device_context_;
|
||||
real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
|
||||
} else if (IsKernelActor(from_kernel)) {
|
||||
auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_from_actor);
|
||||
from_devcie_context = real_from_actor->device_context_;
|
||||
real_from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
|
||||
} else if (IsHostQueueDSActor(from_kernel)) {
|
||||
auto real_from_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_from_actor);
|
||||
auto position = real_from_actor->FetchDataNodePosition(from_kernel);
|
||||
from_devcie_context = real_from_actor->device_contexts_[position];
|
||||
op_arrow_to_copy->from_output_index_ = position;
|
||||
|
|
|
@ -203,15 +203,15 @@ class GraphScheduler {
|
|||
const std::vector<AnfNodePtr> &host_parameters, const KernelGraphPtr &graph,
|
||||
KernelActor *to_actor, KernelWithIndex to_kernel_with_input_idx);
|
||||
// Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
|
||||
void LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
|
||||
void LinkDataArrowForCopyActor(OpActor<DeviceTensor> *const from_actor, KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx);
|
||||
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *const from_actor, KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_to_kernel_with_input_idx);
|
||||
void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
void LinkDataArrowForHostDSActor(HostQueueDataSourceActor *const from_actor, KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx);
|
||||
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
|
||||
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *const to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx);
|
||||
|
||||
|
|
Loading…
Reference in New Issue