From 5535e804140586c322f6365e38432a125ff75c24 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Wed, 7 Jul 2021 17:31:49 +0800 Subject: [PATCH] [bugfix] custom bprop got unexpected input --- .../ccsrc/backend/session/session_basic.cc | 30 ++++++++ .../ccsrc/backend/session/session_basic.h | 6 ++ .../ccsrc/runtime/framework/graph_compiler.cc | 10 +++ .../ccsrc/runtime/framework/graph_compiler.h | 5 ++ mindspore/ccsrc/vm/backend.cc | 70 ++++++++++++++++--- 5 files changed, 110 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 8c6fcc2d894..7d7df0418fb 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1406,6 +1406,36 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode, } } +tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode, + const std::map &op_output, + const std::map ¶meter_index, + const std::vector &graph_inputs, + InputTensorInfo *input_tensor_info, size_t input_index) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(input_tensor_info); + if (input_index >= cnode->inputs().size() - 1) { + MS_LOG(EXCEPTION) << "Input index is out of range:" << cnode->inputs().size() << ",cnode:" << cnode->DebugString(); + } + + const auto &input = cnode->input(input_index + 1); + auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); + auto real_input = kernel_with_index.first; + MS_EXCEPTION_IF_NULL(real_input); + + if (real_input->isa()) { + return GetParameterOutputTensor(real_input, parameter_index, graph_inputs); + } else if (real_input->isa()) { + tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output); + if (AnfAlgo::IsControlOpExecInBackend(real_input)) { + CheckInputTensorShape(tensor, cnode, input_index); + } + input_tensor_info->input_kernel.insert(kernel_with_index); + return tensor; + } else { + MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString(); + } +} + bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 3a75c43b63e..2ba31599b9d 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -258,6 +258,12 @@ class SessionBasic : public std::enable_shared_from_this { void GetOpInputTensors(const CNodePtr &cnode, const std::map &op_output, const std::map ¶meter_index, const std::vector &graph_inputs, InputTensorInfo *input_tensor_info); + tensor::TensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode, + const std::map &op_output, + const std::map ¶meter_index, + const std::vector &graph_inputs, + InputTensorInfo *input_tensor_info, size_t input_index); + // create a new kernel graph and update the graph sum KernelGraphPtr NewKernelGraph(); AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph); diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc index 2fae9f89ffc..2fe7caf1372 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc @@ -446,6 +446,16 @@ void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel, session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info); } +TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel, + const std::map &op_output, + const std::map ¶meter_index, + const std::vector &graph_inputs, + InputTensorInfo *input_tensor_info, size_t input_index) { + MS_EXCEPTION_IF_NULL(session_); + return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info, + input_index); +} + void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector &input_tensors, OpRunInfo *run_info, GraphInfo *graph_info) { MS_EXCEPTION_IF_NULL(session_); diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.h b/mindspore/ccsrc/runtime/framework/graph_compiler.h index e29f00842d6..50cc891232a 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.h +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.h @@ -73,6 +73,11 @@ class GraphCompiler { void GetSingleOpInputTensors(const CNodePtr &kernel, const std::map &op_output, const std::map ¶meter_index, const std::vector &graph_inputs, InputTensorInfo *input_tensor_info); + // Get one input tensor for single control op, such as bprop_cut. + TensorPtr GetSingleOpInputTensorByIndex(const CNodePtr &kernel, const std::map &op_output, + const std::map ¶meter_index, + const std::vector &graph_inputs, + InputTensorInfo *input_tensor_info, size_t input_index); // Get OpRunInfo and GraphInfo for single op compile and run. void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector &input_tensors, diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index e100ae01509..35373088a77 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -473,6 +473,49 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const } namespace { +void GetControlOpInput(const std::shared_ptr &graph_compiler, const CNodePtr &front_cnode, + const CNodePtr &backend_cnode, const std::map &op_output_map, + const std::map ¶meter_index, + const std::vector &graph_inputs, InputTensorInfo *input_tensor_info, + VectorRef *args) { + MS_EXCEPTION_IF_NULL(front_cnode); + MS_EXCEPTION_IF_NULL(backend_cnode); + size_t input_index = 0; + auto inputs = front_cnode->inputs(); + for (size_t i = 1; i < inputs.size(); i++) { + const auto &input_node = inputs[i]; + MS_EXCEPTION_IF_NULL(input_node); + auto kernel_with_index = AnfAlgo::VisitKernel(input_node, 0); + auto real_input = kernel_with_index.first; + MS_EXCEPTION_IF_NULL(real_input); + + if (!real_input->isa()) { + TensorPtr tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index, + graph_inputs, input_tensor_info, input_index); + MS_EXCEPTION_IF_NULL(tensor); + args->emplace_back(tensor); + input_index++; + continue; + } + + // Get value from value node. + const auto &value_node = real_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + const auto &value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + + if (value->isa()) { + const auto &value_sequeue = value->cast(); + MS_EXCEPTION_IF_NULL(value_sequeue); + input_index += value_sequeue->size(); + } else { + input_index++; + } + + args->emplace_back(value); + } +} + void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector *tensors) { MS_EXCEPTION_IF_NULL(tensors); for (const auto &input_object : tuple_inputs) { @@ -519,7 +562,10 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector *input_tensors, +void RunControlOperator(const std::shared_ptr graph_compiler, const KernelGraphPtr &graph, + const CNodePtr &kernel, const std::map &op_output_map, + const std::map ¶meter_index, + const std::vector &graph_inputs, InputTensorInfo *input_tensor_info, VectorRef *op_outputs) { AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel); MS_EXCEPTION_IF_NULL(front_node); @@ -541,8 +587,8 @@ void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, s PrimitivePtr prim = GetValueNode(fn); if (prim->name() == kBpropCutOpName) { VectorRef args; - (void)std::transform(input_tensors->begin(), input_tensors->end(), std::back_inserter(args.elements_), - [](tensor::TensorPtr &tensor) { return std::move(tensor); }); + GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info, + &args); BaseRef out = prim->RunHookFunction(args); @@ -666,22 +712,24 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector &graphs } for (const auto &kernel : graph->execution_order()) { - OpRunInfo op_run_info; - GraphInfo graph_info; InputTensorInfo input_tensor_info; - graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], - &input_tensor_info); - graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info, - &graph_info); - VectorRef op_outputs; + if (!AnfAlgo::IsControlOpExecInBackend(kernel)) { + OpRunInfo op_run_info; + GraphInfo graph_info; + graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], + &input_tensor_info); + graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info, + &graph_info); + const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors); RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors, &op_outputs); } else { - RunControlOperator(graph, kernel, &input_tensor_info.input_tensors, &op_outputs); + RunControlOperator(graph_compiler_, graph, kernel, op_output_map, parameter_index, inputs[graph_index], + &input_tensor_info, &op_outputs); } graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);