fix output tensor num bug

This commit is contained in:
caifubi 2021-11-26 11:47:29 +08:00
parent bd6de7162b
commit 1f25ce5d98
3 changed files with 11 additions and 1 deletions

View File

@ -1051,6 +1051,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
}
size_t AnfRuntimeAlgorithm::GetOutputAddressNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetOutputNum();
}
// set output device addr of anf_node
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);

View File

@ -199,6 +199,7 @@ class AnfRuntimeAlgorithm {
bool skip_nop_node = true);
static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
bool skip_nop_node = true);
static size_t GetOutputAddressNum(const AnfNodePtr &node);
// set output device addr of anf_node
static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
// set workspace device addr of anf_node

View File

@ -182,7 +182,7 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const
continue;
}
auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
for (size_t i = 0; i < output_size; ++i) {
if (AnfAlgo::OutputAddrExist(kernel, i)) {
continue;