fix output tensor num bug
This commit is contained in:
parent
bd6de7162b
commit
1f25ce5d98
|
@ -1051,6 +1051,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode
|
|||
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
|
||||
}
|
||||
|
||||
size_t AnfRuntimeAlgorithm::GetOutputAddressNum(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
auto build_info = kernel_info->select_kernel_build_info();
|
||||
MS_EXCEPTION_IF_NULL(build_info);
|
||||
return build_info->GetOutputNum();
|
||||
}
|
||||
|
||||
// set output device addr of anf_node
|
||||
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
|
|
@ -199,6 +199,7 @@ class AnfRuntimeAlgorithm {
|
|||
bool skip_nop_node = true);
|
||||
static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool skip_nop_node = true);
|
||||
static size_t GetOutputAddressNum(const AnfNodePtr &node);
|
||||
// set output device addr of anf_node
|
||||
static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
|
||||
// set workspace device addr of anf_node
|
||||
|
|
|
@ -182,7 +182,7 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
|
|||
continue;
|
||||
}
|
||||
|
||||
auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
|
||||
auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
|
||||
for (size_t i = 0; i < output_size; ++i) {
|
||||
if (AnfAlgo::OutputAddrExist(kernel, i)) {
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue