forked from mindspore-Ecosystem/mindspore
!49282 Fix some bug for PyNative
Merge pull request !49282 from zjun/2002
commit 5de427bfed
@@ -558,6 +558,11 @@ void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector
       parameter_index->emplace(param, index++);
       continue;
     }
+    // Input is scalar. param shape will be [1], input shape will be []
+    if (param_shape.size() == 1 && input_shape.empty()) {
+      parameter_index->emplace(param, index++);
+      continue;
+    }
     MS_LOG(EXCEPTION) << "Shape size of input tensor(" << input_shape << ") and parameter(" << param_shape
                       << ") are different, input index: " << index << ", parameter: " << param->DebugString();
   }
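
The new branch treats a parameter declared with shape [1] as compatible with a scalar input whose shape is []. Below is a standalone C++ sketch of that check, with plain std::vector shapes standing in for the MindSpore shape types (names are illustrative, not the real API):

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when the shapes match exactly, or when the parameter is declared as [1]
// and the incoming input is a scalar with an empty shape (the case the hunk adds).
bool ShapesCompatible(const std::vector<int64_t> &param_shape, const std::vector<int64_t> &input_shape) {
  if (param_shape == input_shape) {
    return true;
  }
  // Input is scalar: param shape will be [1], input shape will be []
  return param_shape.size() == 1 && input_shape.empty();
}

int main() {
  std::cout << ShapesCompatible({1}, {}) << "\n";      // 1: scalar input accepted
  std::cout << ShapesCompatible({2, 3}, {3}) << "\n";  // 0: genuine mismatch, would hit the exception above
  return 0;
}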
@@ -581,17 +581,19 @@ FuncGraphPtr AutoGradCellImpl::GradFuncGraph(const GradParamPtr &grad_param) {
 
   GradGraphByExpander(grad_param);
 
-  // Set dout parameter
-  if (kMonadPrim.find(GetCNodePrimitive(ad_param()->last_node_)) != kMonadPrim.end()) {
-    ad_param()->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(ad_param()->last_node_, 0, false,
-                                                                        {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
-                               .first;
-  }
-  auto ad_graph_dout = ad_param()->tape_->add_parameter();
-  ad_graph_dout->set_abstract(ad_param()->last_node_->abstract());
-  ad_param()->anfnode_to_variable_adjoint_.at(ad_param()->last_node_)->fn()->UpdateAccumulativeDout(ad_graph_dout);
-
-  (void)BackPropagate();
+  if (ad_param()->last_node_ != nullptr) {
+    // Set dout parameter
+    if (kMonadPrim.find(GetCNodePrimitive(ad_param()->last_node_)) != kMonadPrim.end()) {
+      ad_param()->last_node_ = common::AnfAlgo::VisitKernelWithReturnType(
+                                 ad_param()->last_node_, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple})
+                                 .first;
+    }
+    auto ad_graph_dout = ad_param()->tape_->add_parameter();
+    ad_graph_dout->set_abstract(ad_param()->last_node_->abstract());
+    ad_param()->anfnode_to_variable_adjoint_.at(ad_param()->last_node_)->fn()->UpdateAccumulativeDout(ad_graph_dout);
+
+    (void)BackPropagate();
+  }
 
   AnfNodePtrList outputs{NewValueNode(prim::kPrimMakeTuple)};
   abstract::AbstractBasePtrList out_abs_list;
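
This hunk wraps the dout setup and backpropagation in a null check on ad_param()->last_node_, so the graph-gradient path no longer dereferences a missing last node. A minimal standalone sketch of the guard pattern (Node and BuildDoutAndBackProp are placeholders, not the MindSpore types):

#include <iostream>
#include <memory>

struct Node {
  int value = 0;
};

// Everything that reads the last node (abstract, adjoint lookup, backprop) stays inside the guard;
// when no last node was produced, dout setup and backpropagation are simply skipped.
void BuildDoutAndBackProp(const std::shared_ptr<Node> &last_node) {
  if (last_node != nullptr) {
    std::cout << "dout built for node with value " << last_node->value << "\n";
  }
}

int main() {
  BuildDoutAndBackProp(std::make_shared<Node>());  // guarded path runs
  BuildDoutAndBackProp(nullptr);                   // no crash: the guard skips the body
  return 0;
}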
@@ -655,7 +657,9 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
     k_node->set_abstract(cnode->abstract());
     // In ms function, copy forward graph cnode info to bprop graph
     if (ms_function_by_value && cnode->forward().first != nullptr) {
-      k_node->set_forward(cnode->forward().first, cnode->forward().second);
+      auto new_v_node = NewValueNode(cnode->forward().first->value());
+      new_v_node->set_abstract(cnode->forward().first->abstract());
+      k_node->set_forward(new_v_node, cnode->forward().second);
       ad_param()->tape_->set_used_forward_nodes({k_node});
     }
     MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
@@ -856,14 +860,15 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
       } else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
         return BuildKNodeForTupleGetItem(input_node);
       }
-      MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << input_node->DebugString();
+      MS_LOG(EXCEPTION) << "Can not find input in adjoint map, inp: " << input_node->DebugString();
     }
     return input_adjoint_iter->second->k_node();
   } else {
     // Tuple sens will come in
-    if (input_node->isa<Parameter>() && input_node->abstract()->isa<abstract::AbstractSequence>()) {
+    if (input_node->isa<Parameter>()) {
       const auto input_adjoint_iter = ad_param()->anfnode_to_variable_adjoint_.find(input_node);
-      if (input_adjoint_iter != ad_param()->anfnode_to_variable_adjoint_.end()) {
+      if (input_adjoint_iter != ad_param()->anfnode_to_variable_adjoint_.end() &&
+          input_adjoint_iter->second->k_node() != nullptr) {
        return input_adjoint_iter->second->k_node();
       }
     }
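
The tightened lookup returns the cached k_node only when the adjoint entry exists and its k_node has actually been built. A standalone sketch of that two-part check, with std::map standing in for anfnode_to_variable_adjoint_ and a placeholder Adjoint type (not the MindSpore classes):

#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Adjoint {
  std::shared_ptr<std::string> k_node;  // may still be null even when the adjoint entry exists
};

// Return the cached k_node only if the entry is present *and* k_node is set; otherwise signal
// the caller to fall back to building a fresh k_node for this input.
std::shared_ptr<std::string> LookupKNode(const std::map<int, Adjoint> &adjoints, int node_id) {
  const auto it = adjoints.find(node_id);
  if (it != adjoints.end() && it->second.k_node != nullptr) {
    return it->second.k_node;
  }
  return nullptr;
}

int main() {
  std::map<int, Adjoint> adjoints;
  adjoints[1] = Adjoint{nullptr};                                    // entry exists but k_node not built
  adjoints[2] = Adjoint{std::make_shared<std::string>("k_node_2")};  // fully built entry
  std::cout << (LookupKNode(adjoints, 1) == nullptr) << "\n";  // 1: fall back to building a k_node
  std::cout << (LookupKNode(adjoints, 2) != nullptr) << "\n";  // 1: cached k_node returned
  return 0;
}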
@@ -29,6 +29,7 @@
 namespace mindspore {
 namespace pynative {
 namespace {
+const char kAddedValue[] = "added_value";
 const mindspore::HashSet<std::string> kNotRealOP{prim::kPrimMakeTuple->name(),
                                                  prim::kPrimTupleGetItem->name(),
                                                  prim::kPrimStopGradient->name(),
@@ -54,7 +55,6 @@ FrontendOpRunInfoPtr GetOpRunInfo(const py::object &out, const py::args &args, c
   MS_EXCEPTION_IF_NULL(added_out_v);
   // Forward output of op in ms_function graph
   *added_out_v = PyNativeAlgo::DataConvert::PyObjToValue(tuple_out[1]);
   MS_LOG(DEBUG) << "Added output value is: " << (*added_out_v)->ToString();
   auto op_run_info = std::make_shared<FrontendOpRunInfo>();
   PyNativeAlgo::PyParser::ParseOpInputByPythonObj(op_run_info, args);
   op_run_info->base_op_run_info.op_name = graph_phase;
@@ -237,13 +237,13 @@ void MsFunction::ReplaceWithRealTensorsInGradGraph(const GradExecutor *grad_exec
   // to real value.
   RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
-  grad_executor->top_cell()->set_op_info_with_ms_func_forward_tensors(op_run_info->op_info, total_output_tensors);
+  grad_executor->top_cell()->set_opinfo_with_tensor_id(op_run_info->op_info + kAddedValue, total_output_tensors);
 }
 
 void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const TopCellInfoPtr &top_cell,
                                                 const string &op_info, const ValuePtr &new_forward_value) const {
   MS_EXCEPTION_IF_NULL(new_forward_value);
   MS_LOG(DEBUG) << "Ms func graph has already ran before. The graph phase is: " << graph_phase_;
   MS_LOG(DEBUG) << "The output values of added forward nodes are: " << new_forward_value->ToString();
   std::vector<tensor::TensorPtr> new_tensors;
   TensorValueToTensor(new_forward_value, &new_tensors);
   if (new_tensors.empty()) {
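
With kAddedValue defined above, the replacement call records the forward tensors of the added nodes under the key op_info + kAddedValue instead of the plain op_info. A small sketch of that keying scheme, where a std::map and string ids stand in for the real TopCellInfo member and tensor ids (the key values are hypothetical):

#include <iostream>
#include <map>
#include <string>
#include <vector>

const char kAddedValue[] = "added_value";

int main() {
  std::map<std::string, std::vector<std::string>> op_info_with_tensor_id;
  const std::string op_info = "ms_function_phase_0";  // hypothetical op_info value

  // Tensors of the op itself are recorded under the plain op_info key ...
  op_info_with_tensor_id[op_info] = {"t0", "t1"};
  // ... while forward tensors of the added nodes go under op_info + kAddedValue,
  // presumably so the two groups can be looked up independently on later runs.
  op_info_with_tensor_id[op_info + kAddedValue] = {"t2"};

  std::cout << op_info_with_tensor_id.size() << "\n";  // 2 distinct entries
  return 0;
}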
@@ -250,9 +250,9 @@ void TopCellInfo::set_opinfo_with_tensor_id(const std::string &op_info,
                   << " in op_info_with_tensor_id map";
   }
   // Record the relationship between the forward op and its output tensor id
-  (void)std::for_each(op_out_tensors.begin(), op_out_tensors.end(), [this, &op_info](const tensor::TensorPtr &tensor) {
+  for (const auto &tensor : op_out_tensors) {
     (void)op_info_with_tensor_id_[op_info].emplace_back(tensor->id());
-  });
+  }
 }
 } // namespace pynative
 } // namespace mindspore
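
The last hunk replaces a std::for_each with a capturing lambda by a plain range-based for. A standalone comparison of the two forms, with simplified types rather than the TopCellInfo members (the key is a hypothetical value):

#include <algorithm>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::vector<std::string>> op_info_with_tensor_id;
  const std::vector<std::string> tensor_ids{"t0", "t1"};
  const std::string op_info = "MatMul-1";  // hypothetical key

  // Old style: std::for_each with a capturing lambda.
  std::for_each(tensor_ids.begin(), tensor_ids.end(),
                [&](const std::string &id) { (void)op_info_with_tensor_id[op_info].emplace_back(id); });

  // New style: a range-based for does the same work without the lambda boilerplate.
  for (const auto &id : tensor_ids) {
    (void)op_info_with_tensor_id[op_info].emplace_back(id);
  }
  return 0;
}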