!39044 Fix error log and refactor func for dynamic shape.

Merge pull request !39044 from TronZhang/fix_error_log_from_1.8_to_master
This commit is contained in:
i-robot 2022-08-03 09:30:56 +00:00 committed by Gitee
commit f069b6372a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 13 additions and 7 deletions

View File

@ -19,6 +19,7 @@
#include <memory>
#include <stack>
#include <set>
#include <string>
#include <vector>
#include <map>
#include <utility>
@ -95,12 +96,15 @@ bool InferShapeForDefiniteOutputNode(const CNodePtr &cnode) {
return true;
}
tensor::TensorPtr GetDependValueTensor(size_t i, const AnfNodePtr &real_input, size_t real_input_index,
bool skip_nop_node, void *args) {
tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
const std::pair<AnfNodePtr, size_t> &input_node_with_index, bool skip_nop_node,
void *args) {
auto real_input = input_node_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
auto real_input_index = input_node_with_index.second;
auto shapes = trans::GetRuntimePaddingShape(real_input, real_input_index);
auto host_type = common::AnfAlgo::GetOutputInferDataType(real_input, real_input_index);
auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
MS_EXCEPTION_IF_NULL(out_tensor);
auto output_addr = AnfAlgo::GetMutableOutputAddr(real_input, real_input_index, skip_nop_node);
if (output_addr != nullptr && output_addr->GetPtr() != nullptr) {
@ -111,11 +115,12 @@ tensor::TensorPtr GetDependValueTensor(size_t i, const AnfNodePtr &real_input, s
} else {
// If real_input is parameter and is control flow's output, the device address stored in AnfNode is useless.
if (args == nullptr) {
MS_LOG(EXCEPTION) << "Address is nullptr, something error";
MS_LOG(EXCEPTION) << "Address is nullptr, and no valid address args is passed!";
}
auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
if (i >= input_device_address->size() || input_device_address->at(i) == nullptr) {
MS_LOG(EXCEPTION) << "i is nullptr, something error";
MS_EXCEPTION_IF_NULL(node);
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->fullname_with_scope();
}
out_tensor->data_sync_directly(input_device_address->at(i));
@ -154,7 +159,7 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
InferShapeForNopNode(real_input);
}
if (depend_list.find(i) != depend_list.end()) {
auto out_tensor = GetDependValueTensor(i, real_input, real_input_index, skip_nop_node, args);
auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args);
auto ret2 = depend_tensor_map->try_emplace(i, out_tensor);
if (!ret2.second) {
MS_LOG(EXCEPTION) << "Insert map failed.";
@ -177,6 +182,8 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
}
auto eval_result = opt::CppInferShapeAndType(primitive, args_spec_list);
MS_LOG(DEBUG) << "Infer result of " << cnode->fullname_with_scope() << " is: " << eval_result;
cnode->set_abstract(eval_result);
}
} // namespace

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DYNAMIC_SHAPE_HELPER_H
#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_DYNAMIC_SHAPE_DYNAMIC_SHAPE_HELPER_H
#include <string>
#include "ir/anf.h"
#include "utils/ms_utils.h"
#include "backend/common/optimizer/optimizer.h"