!28736 Fix bug of PyNative MindRT ValueNode

Merge pull request !28736 from caifubi/master-pynative-mindrt-valuenode
This commit is contained in:
i-robot 2022-01-10 06:34:43 +00:00 committed by Gitee
commit ba610448ca
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 7 additions and 9 deletions

View File

@ -280,8 +280,7 @@ void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
}
std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph(
const KernelGraphPtr &graph, size_t input_tensors_size,
const std::vector<tensor::TensorPtr> &tensors_without_value_node) {
const KernelGraphPtr &graph, const std::vector<tensor::TensorPtr> &tensors_without_value_node) {
std::vector<tensor::TensorPtr> new_input_tensors;
if (graph->execution_order().size() != 1) {
return new_input_tensors;
@ -289,12 +288,12 @@ std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph(
const auto &node = graph->execution_order().back();
auto input_num = AnfAlgo::GetInputTensorNum(node);
// In most scenarios, input_num and input_tensors_size are equal.
// Except for special procedures, new ValueNode will be added to Graph in GraphOptimize.
if (input_num == input_tensors_size) {
// No value node in graph
if (input_num == tensors_without_value_node.size()) {
return new_input_tensors;
}
MS_LOG(INFO) << "CNode input num:" << input_num << " input_tensors size:" << input_tensors_size;
MS_LOG(DEBUG) << "CNode input num:" << input_num
<< " tensors_without_value_node size:" << tensors_without_value_node.size();
std::map<size_t, tensor::TensorPtr> value_node_pos;
for (size_t i = 0; i < input_num; ++i) {
@ -320,7 +319,7 @@ std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph(
new_input_tensors.emplace_back(iter->second);
}
}
MS_LOG(INFO) << "new input tensor size:" << new_input_tensors.size();
MS_LOG(DEBUG) << "new input tensor size:" << new_input_tensors.size();
return new_input_tensors;
}
@ -1218,8 +1217,7 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph,
}
}
std::vector<tensor::TensorPtr> new_input_tensors =
GetRealValueNodeTensorFromGraph(graph, input_tensors.size(), tensors_without_value_node);
std::vector<tensor::TensorPtr> new_input_tensors = GetRealValueNodeTensorFromGraph(graph, tensors_without_value_node);
for (auto &tensor : tensors_without_value_node) {
MS_EXCEPTION_IF_NULL(tensor);