forked from mindspore-Ecosystem/mindspore
[bugfix] custom bprop got unexpected input
This commit is contained in:
parent
81dbc13611
commit
5535e80414
|
@ -1406,6 +1406,36 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
|
|||
}
|
||||
}
|
||||
|
||||
tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
|
||||
const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
InputTensorInfo *input_tensor_info, size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(input_tensor_info);
|
||||
if (input_index >= cnode->inputs().size() - 1) {
|
||||
MS_LOG(EXCEPTION) << "Input index is out of range:" << cnode->inputs().size() << ",cnode:" << cnode->DebugString();
|
||||
}
|
||||
|
||||
const auto &input = cnode->input(input_index + 1);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
|
||||
auto real_input = kernel_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
|
||||
if (real_input->isa<Parameter>()) {
|
||||
return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
|
||||
} else if (real_input->isa<CNode>()) {
|
||||
tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
|
||||
if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
|
||||
CheckInputTensorShape(tensor, cnode, input_index);
|
||||
}
|
||||
input_tensor_info->input_kernel.insert(kernel_with_index);
|
||||
return tensor;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
|
|
@ -258,6 +258,12 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
|
||||
tensor::TensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode,
|
||||
const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs,
|
||||
InputTensorInfo *input_tensor_info, size_t input_index);
|
||||
|
||||
// create a new kernel graph and update the graph sum
|
||||
KernelGraphPtr NewKernelGraph();
|
||||
AnfNodePtr CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph);
|
||||
|
|
|
@ -446,6 +446,16 @@ void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
|
|||
session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
|
||||
}
|
||||
|
||||
TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
|
||||
const std::map<KernelWithIndex, TensorPtr> &op_output,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<TensorPtr> &graph_inputs,
|
||||
InputTensorInfo *input_tensor_info, size_t input_index) {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
|
||||
input_index);
|
||||
}
|
||||
|
||||
void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
|
||||
OpRunInfo *run_info, GraphInfo *graph_info) {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
|
|
|
@ -73,6 +73,11 @@ class GraphCompiler {
|
|||
void GetSingleOpInputTensors(const CNodePtr &kernel, const std::map<KernelWithIndex, TensorPtr> &op_output,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info);
|
||||
// Get one input tensor for single control op, such as bprop_cut.
|
||||
TensorPtr GetSingleOpInputTensorByIndex(const CNodePtr &kernel, const std::map<KernelWithIndex, TensorPtr> &op_output,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<TensorPtr> &graph_inputs,
|
||||
InputTensorInfo *input_tensor_info, size_t input_index);
|
||||
|
||||
// Get OpRunInfo and GraphInfo for single op compile and run.
|
||||
void GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
|
||||
|
|
|
@ -473,6 +473,49 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
|
|||
}
|
||||
|
||||
namespace {
|
||||
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &front_cnode,
|
||||
const CNodePtr &backend_cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
|
||||
VectorRef *args) {
|
||||
MS_EXCEPTION_IF_NULL(front_cnode);
|
||||
MS_EXCEPTION_IF_NULL(backend_cnode);
|
||||
size_t input_index = 0;
|
||||
auto inputs = front_cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
const auto &input_node = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(input_node, 0);
|
||||
auto real_input = kernel_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
|
||||
if (!real_input->isa<ValueNode>()) {
|
||||
TensorPtr tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
|
||||
graph_inputs, input_tensor_info, input_index);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
args->emplace_back(tensor);
|
||||
input_index++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get value from value node.
|
||||
const auto &value_node = real_input->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const auto &value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
|
||||
if (value->isa<ValueSequeue>()) {
|
||||
const auto &value_sequeue = value->cast<ValueSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_sequeue);
|
||||
input_index += value_sequeue->size();
|
||||
} else {
|
||||
input_index++;
|
||||
}
|
||||
|
||||
args->emplace_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
|
||||
MS_EXCEPTION_IF_NULL(tensors);
|
||||
for (const auto &input_object : tuple_inputs) {
|
||||
|
@ -519,7 +562,10 @@ void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<te
|
|||
}
|
||||
}
|
||||
|
||||
void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, std::vector<TensorPtr> *input_tensors,
|
||||
void RunControlOperator(const std::shared_ptr<GraphCompiler> graph_compiler, const KernelGraphPtr &graph,
|
||||
const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
|
||||
const std::map<AnfNodePtr, size_t> ¶meter_index,
|
||||
const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
|
||||
VectorRef *op_outputs) {
|
||||
AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
|
||||
MS_EXCEPTION_IF_NULL(front_node);
|
||||
|
@ -541,8 +587,8 @@ void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, s
|
|||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
|
||||
if (prim->name() == kBpropCutOpName) {
|
||||
VectorRef args;
|
||||
(void)std::transform(input_tensors->begin(), input_tensors->end(), std::back_inserter(args.elements_),
|
||||
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
|
||||
GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info,
|
||||
&args);
|
||||
|
||||
BaseRef out = prim->RunHookFunction(args);
|
||||
|
||||
|
@ -666,22 +712,24 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
|
|||
}
|
||||
|
||||
for (const auto &kernel : graph->execution_order()) {
|
||||
OpRunInfo op_run_info;
|
||||
GraphInfo graph_info;
|
||||
InputTensorInfo input_tensor_info;
|
||||
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
|
||||
&input_tensor_info);
|
||||
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
|
||||
&graph_info);
|
||||
|
||||
VectorRef op_outputs;
|
||||
|
||||
if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
|
||||
OpRunInfo op_run_info;
|
||||
GraphInfo graph_info;
|
||||
graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
|
||||
&input_tensor_info);
|
||||
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
|
||||
&graph_info);
|
||||
|
||||
const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask,
|
||||
&input_tensor_info.input_tensors);
|
||||
RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
|
||||
&op_outputs);
|
||||
} else {
|
||||
RunControlOperator(graph, kernel, &input_tensor_info.input_tensors, &op_outputs);
|
||||
RunControlOperator(graph_compiler_, graph, kernel, op_output_map, parameter_index, inputs[graph_index],
|
||||
&input_tensor_info, &op_outputs);
|
||||
}
|
||||
|
||||
graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);
|
||||
|
|
Loading…
Reference in New Issue