forked from mindspore-Ecosystem/mindspore
!203 fix reshape as output and release mem exception
Merge pull request !203 from kisnwang/optimize-allreduce-mem-malloc
This commit is contained in:
commit
18b9a0957e
|
@ -85,8 +85,10 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
|
|||
MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
mem_manager_->FreeDeviceMemory();
|
||||
if (mem_manager_ != nullptr) {
|
||||
mem_manager_->FreeDeviceMemory();
|
||||
}
|
||||
|
||||
(void)DestroyHccl();
|
||||
(void)ResetDevice();
|
||||
(void)ProfilingManager::GetInstance().StopProfiling();
|
||||
|
|
|
@ -101,8 +101,9 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
|
|||
CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue.");
|
||||
}
|
||||
GPUDeviceManager::GetInstance().ReleaseDevice();
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
mem_manager_->FreeDeviceMemory();
|
||||
if (mem_manager_ != nullptr) {
|
||||
mem_manager_->FreeDeviceMemory();
|
||||
}
|
||||
}
|
||||
|
||||
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
||||
|
|
|
@ -112,6 +112,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|||
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
|
||||
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0);
|
||||
} else if (opt::IsNopNode(cnode)) {
|
||||
if (cnode->inputs().size() == 2) {
|
||||
return VisitKernelWithReturnType(cnode->input(1), 0);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
|
||||
}
|
||||
} else {
|
||||
return std::make_pair(anf_node, index);
|
||||
}
|
||||
|
@ -299,20 +305,23 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
|
|||
return build_info->GetInputFormat(input_idx);
|
||||
}
|
||||
|
||||
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
|
||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
|
||||
<< ".";
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
||||
}
|
||||
auto node = cnode->input(input_idx + 1);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
||||
return VisitKernel(node, 0);
|
||||
}
|
||||
|
||||
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
|
@ -346,18 +355,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
|
|||
}
|
||||
|
||||
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
|
||||
<< ".";
|
||||
}
|
||||
auto input_node = cnode->input(input_idx + 1);
|
||||
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
|
@ -459,17 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
|
|||
}
|
||||
|
||||
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << node->DebugString() << "is not a CNode";
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
||||
}
|
||||
auto input_node = cnode->input(input_idx + 1);
|
||||
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
|
@ -492,17 +480,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
|
|||
}
|
||||
|
||||
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
||||
}
|
||||
auto node = cnode->input(input_idx + 1);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
|
@ -558,32 +536,12 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
|
|||
}
|
||||
|
||||
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf node is not a CNode";
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
||||
}
|
||||
auto node = cnode->input(input_idx + 1);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
||||
}
|
||||
auto node = cnode->input(input_idx + 1);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
|
||||
}
|
||||
|
||||
|
|
|
@ -89,6 +89,8 @@ class AnfRuntimeAlgorithm {
|
|||
static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
|
||||
// get input format select of anf node
|
||||
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
|
||||
// get prev node output width output index
|
||||
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
|
||||
// get output format from prev node,input_index is the input index of current node related to prev node
|
||||
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
|
||||
// get output shapes inferred by ME from input nodes.
|
||||
|
|
Loading…
Reference in New Issue