!203 fix reshape as output and release mem exception

Merge pull request !203 from kisnwang/optimize-allreduce-mem-malloc
This commit is contained in:
mindspore-ci-bot 2020-04-13 15:26:06 +08:00 committed by Gitee
commit 18b9a0957e
4 changed files with 28 additions and 65 deletions

View File

@ -85,8 +85,10 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]";
}
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->FreeDeviceMemory();
if (mem_manager_ != nullptr) {
mem_manager_->FreeDeviceMemory();
}
(void)DestroyHccl();
(void)ResetDevice();
(void)ProfilingManager::GetInstance().StopProfiling();

View File

@ -101,8 +101,9 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue.");
}
GPUDeviceManager::GetInstance().ReleaseDevice();
MS_EXCEPTION_IF_NULL(mem_manager_);
mem_manager_->FreeDeviceMemory();
if (mem_manager_ != nullptr) {
mem_manager_->FreeDeviceMemory();
}
}
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {

View File

@ -112,6 +112,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0);
} else if (opt::IsNopNode(cnode)) {
if (cnode->inputs().size() == 2) {
return VisitKernelWithReturnType(cnode->input(1), 0);
} else {
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
}
} else {
return std::make_pair(anf_node, index);
}
@ -299,20 +305,23 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
return build_info->GetInputFormat(input_idx);
}
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
<< ".";
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
return VisitKernel(node, 0);
}
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}
@ -346,18 +355,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n
}
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf_node is not CNode.";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode)
<< ".";
}
auto input_node = cnode->input(input_idx + 1);
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
}
@ -459,17 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_
}
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
MS_LOG(EXCEPTION) << node->DebugString() << "is not a CNode";
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto input_node = cnode->input(input_idx + 1);
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
}
@ -492,17 +480,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
}
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
}
@ -558,32 +536,12 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
}
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf node is not a CNode";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
}
DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (input_idx + 1 >= cnode->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
}
auto node = cnode->input(input_idx + 1);
MS_EXCEPTION_IF_NULL(node);
KernelWithIndex kernel_with_index = VisitKernel(node, 0);
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
}

View File

@ -89,6 +89,8 @@ class AnfRuntimeAlgorithm {
static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
// get input format select of anf node
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
// get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
// get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes.