!26871 fix output tensor num bug

Merge pull request !26871 from caifubi/master-pynative-lazy-build-bug
This commit is contained in:
i-robot 2021-11-30 01:28:46 +00:00 committed by Gitee
commit 8932dddfd9
3 changed files with 10 additions and 1 deletions

View File

@ -15,6 +15,8 @@
*/ */
#include "backend/kernel_compiler/kernel_build_info.h" #include "backend/kernel_compiler/kernel_build_info.h"
#include <algorithm>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
namespace mindspore { namespace mindspore {
@ -65,6 +67,11 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
size_t KernelBuildInfo::GetOutputNumWithoutMonad() const {
return std::count_if(outputs_device_type_.begin(), outputs_device_type_.end(),
[](TypeId type) { return type != TypeId::kObjectTypeUMonad; });
}
std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const { std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
if (input_reshape_type_.empty()) { if (input_reshape_type_.empty()) {
return ""; return "";

View File

@ -94,6 +94,8 @@ class KernelBuildInfo {
size_t GetOutputNum() const; size_t GetOutputNum() const;
size_t GetOutputNumWithoutMonad() const;
std::string ToString() const; std::string ToString() const;
bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const; bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;

View File

@ -1038,7 +1038,7 @@ size_t AnfRuntimeAlgorithm::GetOutputAddressNum(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(kernel_info); MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info(); auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info); MS_EXCEPTION_IF_NULL(build_info);
return build_info->GetOutputNum(); return build_info->GetOutputNumWithoutMonad();
} }
// set output device addr of anf_node // set output device addr of anf_node