[bugfix] custom bprop got unexpected input

This commit is contained in:
lizhenyu 2021-07-07 17:31:49 +08:00
parent 81dbc13611
commit 5535e80414
5 changed files with 110 additions and 11 deletions

View File

@ -1406,6 +1406,36 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
}
}
tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
InputTensorInfo *input_tensor_info, size_t input_index) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(input_tensor_info);
if (input_index >= cnode->inputs().size() - 1) {
MS_LOG(EXCEPTION) << "Input index is out of range:" << cnode->inputs().size() << ",cnode:" << cnode->DebugString();
}
const auto &input = cnode->input(input_index + 1);
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
if (real_input->isa<Parameter>()) {
return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
} else if (real_input->isa<CNode>()) {
tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
CheckInputTensorShape(tensor, cnode, input_index);
}
input_tensor_info->input_kernel.insert(kernel_with_index);
return tensor;
} else {
MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
}
}
bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);

View File

@ -258,6 +258,12 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
tensor::TensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode,
const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs,
InputTensorInfo *input_tensor_info, size_t input_index);
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);

View File

@ -446,6 +446,16 @@ void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
}
TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
const std::map<KernelWithIndex, TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<TensorPtr> &graph_inputs,
InputTensorInfo *input_tensor_info, size_t input_index) {
MS_EXCEPTION_IF_NULL(session_);
return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
input_index);
}
void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
OpRunInfo *run_info, GraphInfo *graph_info) {
MS_EXCEPTION_IF_NULL(session_);

View File

@ -73,6 +73,11 @@ class GraphCompiler {
void GetSingleOpInputTensors(const CNodePtr &kernel, const std::map<KernelWithIndex, TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
// Get one input tensor for single control op, such as bprop_cut.
TensorPtr GetSingleOpInputTensorByIndex(const CNodePtr &kernel, const std::map<KernelWithIndex, TensorPtr> &op_output,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<TensorPtr> &graph_inputs,
InputTensorInfo *input_tensor_info, size_t input_index);
// Get OpRunInfo and GraphInfo for single op compile and run.
void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,

View File

@ -473,6 +473,49 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
}
namespace {
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &front_cnode,
const CNodePtr &backend_cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
VectorRef *args) {
MS_EXCEPTION_IF_NULL(front_cnode);
MS_EXCEPTION_IF_NULL(backend_cnode);
size_t input_index = 0;
auto inputs = front_cnode->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
const auto &input_node = inputs[i];
MS_EXCEPTION_IF_NULL(input_node);
auto kernel_with_index = AnfAlgo::VisitKernel(input_node, 0);
auto real_input = kernel_with_index.first;
MS_EXCEPTION_IF_NULL(real_input);
if (!real_input->isa<ValueNode>()) {
TensorPtr tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
graph_inputs, input_tensor_info, input_index);
MS_EXCEPTION_IF_NULL(tensor);
args->emplace_back(tensor);
input_index++;
continue;
}
// Get value from value node.
const auto &value_node = real_input->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
const auto &value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueSequeue>()) {
const auto &value_sequeue = value->cast<ValueSequeuePtr>();
MS_EXCEPTION_IF_NULL(value_sequeue);
input_index += value_sequeue->size();
} else {
input_index++;
}
args->emplace_back(value);
}
}
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
for (const auto &input_object : tuple_inputs) {
@ -519,7 +562,10 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<te
}
}
void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, std::vector<TensorPtr> *input_tensors,
void RunControlOperator(const std::shared_ptr<GraphCompiler> graph_compiler, const KernelGraphPtr &graph,
const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
const std::map<AnfNodePtr, size_t> &parameter_index,
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
VectorRef *op_outputs) {
AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
MS_EXCEPTION_IF_NULL(front_node);
@ -541,8 +587,8 @@ void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, s
PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
if (prim->name() == kBpropCutOpName) {
VectorRef args;
(void)std::transform(input_tensors->begin(), input_tensors->end(), std::back_inserter(args.elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info,
&args);
BaseRef out = prim->RunHookFunction(args);
@ -666,22 +712,24 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
}
for (const auto &kernel : graph->execution_order()) {
OpRunInfo op_run_info;
GraphInfo graph_info;
InputTensorInfo input_tensor_info;
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
&input_tensor_info);
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
&graph_info);
VectorRef op_outputs;
if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
OpRunInfo op_run_info;
GraphInfo graph_info;
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
&input_tensor_info);
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
&graph_info);
const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask,
&input_tensor_info.input_tensors);
RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
&op_outputs);
} else {
RunControlOperator(graph, kernel, &input_tensor_info.input_tensors, &op_outputs);
RunControlOperator(graph_compiler_, graph, kernel, op_output_map, parameter_index, inputs[graph_index],
&input_tensor_info, &op_outputs);
}
graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);