Fix GetOpInputTensors bug

This commit is contained in:
tanghuikang 2021-01-07 11:13:24 +08:00
parent 25d5d43dea
commit 92cb18faab
1 changed files with 4 additions and 3 deletions

View File

@ -268,10 +268,10 @@ void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, te
MS_EXCEPTION_IF_NULL(real_input); MS_EXCEPTION_IF_NULL(real_input);
tensor::TensorPtr tensor = nullptr; tensor::TensorPtr tensor = nullptr;
if (real_input->isa<ValueNode>()) { if (real_input->isa<ValueNode>()) {
auto value_node = input->cast<ValueNodePtr>(); auto value_node = real_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto value = GetValueNode(value_node); auto value = GetValueNode(value_node);
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueTuple>()) { if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>(); auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple); MS_EXCEPTION_IF_NULL(value_tuple);
@ -881,7 +881,7 @@ void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_inf
void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) { VectorRef *outputs) {
MS_LOG(INFO) << "Start"; MS_LOG(INFO) << "Start!";
auto kernel_graph = GetGraph(graph_id); auto kernel_graph = GetGraph(graph_id);
std::map<AnfNodePtr, size_t> parameter_index; std::map<AnfNodePtr, size_t> parameter_index;
GetParameterIndex(kernel_graph.get(), inputs, &parameter_index); GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
@ -910,6 +910,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map); HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs); HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs);
} }
MS_LOG(INFO) << "Finish!";
} }
// compile graph steps // compile graph steps