!49282 Fix some bugs for PyNative

Merge pull request !49282 from zjun/2002
i-robot 2023-02-23 08:27:53 +00:00 committed by Gitee
commit 5de427bfed
4 changed files with 28 additions and 18 deletions


@@ -558,6 +558,11 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
       parameter_index->emplace(param, index++);
       continue;
     }
+    // Input is scalar. param shape will be [1], input shape will be []
+    if (param_shape.size() == 1 && input_shape.empty()) {
+      parameter_index->emplace(param, index++);
+      continue;
+    }
     MS_LOG(EXCEPTION) << "Shape size of input tensor(" << input_shape << ") and parameter(" << param_shape
                       << ") are different, input index: " << index << ", parameter: " << param->DebugString();
   }


@@ -581,17 +581,19 @@ FuncGraphPtr AutoGradCellImpl::GradFuncGraph(const GradParamPtr &grad_param) {
   GradGraphByExpander(grad_param);
-  // Set dout parameter
-  if (kMonadPrim.find(GetCNodePrimitive(ad_param()->last_node_)) != kMonadPrim.end()) {
-    ad_param()->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(ad_param()->last_node_, 0, false,
-                                                                        {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
-                                 .first;
-  }
-  auto ad_graph_dout = ad_param()->tape_->add_parameter();
-  ad_graph_dout->set_abstract(ad_param()->last_node_->abstract());
-  ad_param()->anfnode_to_variable_adjoint_.at(ad_param()->last_node_)->fn()->UpdateAccumulativeDout(ad_graph_dout);
-  (void)BackPropagate();
+  if (ad_param()->last_node_ != nullptr) {
+    // Set dout parameter
+    if (kMonadPrim.find(GetCNodePrimitive(ad_param()->last_node_)) != kMonadPrim.end()) {
+      ad_param()->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(
+                                   ad_param()->last_node_, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
+                                   .first;
+    }
+    auto ad_graph_dout = ad_param()->tape_->add_parameter();
+    ad_graph_dout->set_abstract(ad_param()->last_node_->abstract());
+    ad_param()->anfnode_to_variable_adjoint_.at(ad_param()->last_node_)->fn()->UpdateAccumulativeDout(ad_graph_dout);
+    (void)BackPropagate();
+  }
   AnfNodePtrList outputs{NewValueNode(prim::kPrimMakeTuple)};
   abstract::AbstractBasePtrList out_abs_list;
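
The fix wraps the whole dout-parameter setup in a null check on last_node_, so a graph that produced no last node skips the parameter creation and BackPropagate() as a unit instead of dereferencing null. A reduced sketch of the pattern (hypothetical types, not MindSpore code):

#include <iostream>
#include <memory>

struct Node {
  int abstract = 0;
};

// Every use of the node, plus the backprop step, sits inside one guard, so a
// missing last node skips the block instead of crashing on last_node->abstract.
void FinishBackward(const std::shared_ptr<Node> &last_node) {
  if (last_node != nullptr) {
    std::cout << "dout abstract: " << last_node->abstract << "\n";
    std::cout << "backpropagate\n";
  }
  std::cout << "build graph outputs\n";  // Runs either way, like the code after the guard.
}
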
@@ -655,7 +657,9 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
     k_node->set_abstract(cnode->abstract());
     // In ms function, copy forward graph cnode info to bprop graph
     if (ms_function_by_value && cnode->forward().first != nullptr) {
-      k_node->set_forward(cnode->forward().first, cnode->forward().second);
+      auto new_v_node = NewValueNode(cnode->forward().first->value());
+      new_v_node->set_abstract(cnode->forward().first->abstract());
+      k_node->set_forward(new_v_node, cnode->forward().second);
       ad_param()->tape_->set_used_forward_nodes({k_node});
     }
     MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
@@ -856,14 +860,15 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
       } else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
         return BuildKNodeForTupleGetItem(input_node);
       }
-      MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << input_node->DebugString();
+      MS_LOG(EXCEPTION) << "Can not find input in adjoint map, inp: " << input_node->DebugString();
     }
     return input_adjoint_iter->second->k_node();
   } else {
     // Tuple sens will come in
-    if (input_node->isa<Parameter>() && input_node->abstract()->isa<abstract::AbstractSequence>()) {
+    if (input_node->isa<Parameter>()) {
       const auto input_adjoint_iter = ad_param()->anfnode_to_variable_adjoint_.find(input_node);
-      if (input_adjoint_iter != ad_param()->anfnode_to_variable_adjoint_.end()) {
+      if (input_adjoint_iter != ad_param()->anfnode_to_variable_adjoint_.end() &&
+          input_adjoint_iter->second->k_node() != nullptr) {
         return input_adjoint_iter->second->k_node();
       }
     }
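
Both changes here make the Parameter fallback more defensive: the AbstractSequence restriction is dropped, and a cached k_node is returned only when the adjoint entry exists and its k_node is non-null. The same two-part guard in isolation (hypothetical types):

#include <map>
#include <memory>
#include <string>

struct Adjoint {
  std::shared_ptr<std::string> k_node;
};

// The lookup must succeed AND the cached node must be non-null before it can
// be returned; otherwise fall through to the caller's default handling.
std::shared_ptr<std::string> FindKNode(const std::map<int, Adjoint> &adjoints, int key) {
  const auto it = adjoints.find(key);
  if (it != adjoints.end() && it->second.k_node != nullptr) {
    return it->second.k_node;
  }
  return nullptr;  // Caller falls back to building a fresh node.
}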


@@ -29,6 +29,7 @@
 namespace mindspore {
 namespace pynative {
 namespace {
+const char kAddedValue[] = "added_value";
 const mindspore::HashSet<std::string> kNotRealOP{prim::kPrimMakeTuple->name(),
                                                  prim::kPrimTupleGetItem->name(),
                                                  prim::kPrimStopGradient->name(),
@@ -54,7 +55,6 @@ FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, c
   MS_EXCEPTION_IF_NULL(added_out_v);
   // Forward output of op in ms_function graph
   *added_out_v = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[1]);
-  MS_LOG(DEBUG) << "Added output value is: " << (*added_out_v)->ToString();
   auto op_run_info = std::make_shared<FrontendOpRunInfo>();
   PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args);
   op_run_info->base_op_run_info.op_name = graph_phase;
@@ -237,13 +237,13 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
   // to real value.
   RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
-  grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
+  grad_executor->top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info + kAddedValue, total_output_tensors);
 }
 
 void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const TopCellInfoPtr &top_cell,
                                                 const string &op_info, const ValuePtr &new_forward_value) const {
   MS_EXCEPTION_IF_NULL(new_forward_value);
   MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase_;
   MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
   std::vector<tensor::TensorPtr> new_tensors;
   TensorValueToTensor(new_forward_value, &new_tensors);
   if (new_tensors.empty()) {
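
With the rename, both kinds of tensor lists go through set_opinfo_with_tensor_id, and the added forward outputs are keyed by op_info + kAddedValue so they cannot collide with the op's regular outputs in the same map. A self-contained sketch of the suffixed-key scheme (hypothetical signature):

#include <map>
#include <string>
#include <vector>

const char kAddedValue[] = "added_value";  // Mirrors the constant introduced above.

// One map holds both kinds of entries; suffixing op_info keeps the
// forward-output tensor ids of an ms_function graph separate from the op's
// ordinary output tensor ids.
void RecordTensorIds(std::map<std::string, std::vector<std::string>> *id_map,
                     const std::string &op_info, const std::vector<std::string> &ids,
                     bool is_added_forward_output) {
  const std::string key = is_added_forward_output ? op_info + kAddedValue : op_info;
  auto &slot = (*id_map)[key];
  slot.insert(slot.end(), ids.begin(), ids.end());
}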


@@ -250,9 +250,9 @@ void TopCellInfo::set_opinfo_with_tensor_id(const std::string &op_info,
                       << " in op_info_with_tensor_id map";
   }
   // Record the relationship between the forward op and its output tensor id
-  (void)std::for_each(op_out_tensors.begin(), op_out_tensors.end(), [this, &op_info](const tensor::TensorPtr &tensor) {
+  for (const auto &tensor : op_out_tensors) {
     (void)op_info_with_tensor_id_[op_info].emplace_back(tensor->id());
-  });
+  }
 }
 }  // namespace pynative
 }  // namespace mindspore
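
The last hunk swaps a std::for_each plus capture lambda for an equivalent range-based for, which reads more directly and drops the capture list. The two helpers below are behaviorally identical; the hunk moves from the first form to the second:

#include <algorithm>
#include <string>
#include <vector>

// Before: algorithm + lambda, with an explicit capture of the target vector.
void CollectWithForEach(const std::vector<std::string> &tensors, std::vector<std::string> *ids) {
  (void)std::for_each(tensors.begin(), tensors.end(),
                      [ids](const std::string &t) { (void)ids->emplace_back(t); });
}

// After: equivalent range-based for, no lambda or captures needed.
void CollectWithRangeFor(const std::vector<std::string> &tensors, std::vector<std::string> *ids) {
  for (const auto &t : tensors) {
    (void)ids->emplace_back(t);
  }
}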