fix reallocate memory bug for communication op
This commit is contained in:
parent
adc24f4263
commit
4526ce6845
|
@ -884,6 +884,14 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
|
|||
return kernel_info->OutputAddrExist(output_idx);
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice`
|
||||
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
return kernel_info->WorkspaceAddrExist(output_idx);
|
||||
}
|
||||
|
||||
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool visit_nop_node) {
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
|
|
|
@ -153,6 +153,8 @@ class AnfRuntimeAlgorithm {
|
|||
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
|
||||
// check whether output addr is exist or not
|
||||
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
|
||||
// check whether workspace addr is exist or not
|
||||
static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
|
||||
// get address from prev node,input_index is the input index of current node related to prev node
|
||||
static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
|
||||
bool visit_nop_node = true);
|
||||
|
|
|
@ -81,6 +81,13 @@ DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const {
|
|||
return workspace_address_list_[index];
|
||||
}
|
||||
|
||||
bool KernelInfo::WorkspaceAddrExist(size_t index) const {
|
||||
if (index >= workspace_address_list_.size()) {
|
||||
return false;
|
||||
}
|
||||
return workspace_address_list_[index] != nullptr;
|
||||
}
|
||||
|
||||
bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
|
||||
if (workspace_address_list_.empty()) {
|
||||
// parameter and valuenode
|
||||
|
|
|
@ -55,6 +55,7 @@ class KernelInfo : public KernelInfoDevice {
|
|||
bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
|
||||
DeviceAddress *GetWorkspaceAddr(size_t index) const;
|
||||
DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const;
|
||||
bool WorkspaceAddrExist(size_t index) const;
|
||||
bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
|
||||
void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
|
||||
kernel::KernelMod *MutableKernelMod() const;
|
||||
|
|
|
@ -454,8 +454,8 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|||
std::vector<size_t> align_size_list;
|
||||
for (uint64_t mem_size : output_sizes) {
|
||||
if (AnfAlgo::OutputAddrExist(node, output_index++)) {
|
||||
MS_LOG(INFO) << "communication op addr exist";
|
||||
continue;
|
||||
MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
|
||||
return;
|
||||
}
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
||||
mem_size = mem_manager_->GetCommonAlignSize(mem_size);
|
||||
|
@ -464,6 +464,10 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
|
|||
align_size_list.emplace_back(mem_size);
|
||||
}
|
||||
|
||||
if (align_size_list.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (type == kReuseDynamicMem) {
|
||||
// reuse communication op's all outputs' memory
|
||||
type = kReuseDynamicCommMem;
|
||||
|
@ -533,6 +537,10 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
|
|||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
|
||||
auto input_node = input_node_with_index.first;
|
||||
if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
|
||||
MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
|
||||
return;
|
||||
}
|
||||
DeviceAddressPtr address = nullptr;
|
||||
if (input_node->isa<CNode>()) {
|
||||
address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
|
||||
|
@ -811,6 +819,10 @@ void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
size_t index = 0;
|
||||
for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
|
||||
if (AnfAlgo::WorkspaceAddrExist(node, index)) {
|
||||
MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
|
||||
return;
|
||||
}
|
||||
auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
|
||||
AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
|
||||
index++;
|
||||
|
|
Loading…
Reference in New Issue