!49762 add zeroslike placeholer

Merge pull request !49762 from luochao60/Pynative_optimize_zeros_like_20230303
This commit is contained in:
i-robot 2023-03-04 03:42:23 +00:00 committed by Gitee
commit 4aa246ff8c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 35 additions and 10 deletions

View File

@ -80,6 +80,14 @@ ValueNodePtr GetSparseTensorShapeNode(const ShapeVector &shape) {
return value_shape;
}
AnfNodePtr ZerosLike(const FuncGraphPtr &tape, const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
auto abs = PyNativeAlgo::Common::SetAbstractValueToAnyValue(abstract);
auto output = tape->NewCNode({NewValueNode(prim::kPrimZerosLike), NewValueNode(std::make_shared<tensor::Tensor>(0))});
output->set_abstract(abs);
return output;
}
AnfNodePtr BuildSpecialLikeSparseTensor(const FuncGraphPtr &tape, const ValuePtr &sparse_value,
const AnfNodePtr &dout_value_node) {
MS_EXCEPTION_IF_NULL(tape);
@ -354,9 +362,12 @@ void FunctionNode::ReplaceEdges() {
}
AnfNodePtr VariableAdjoint::RealDout() {
const auto &accumulate_dout = fn()->accumulate_dout();
auto &tape = fn()->tape();
MS_EXCEPTION_IF_NULL(out_value_);
auto &tape = fn()->tape();
if (MS_UNLIKELY(IsZerosLikeNode(fn()->accumulate_dout()))) {
fn()->set_accumulate_dout(BuildZerosLikeNode(fn()->tape(), out_value_));
}
const auto &accumulate_dout = fn()->accumulate_dout();
const auto &dout_abs = accumulate_dout->abstract();
MS_EXCEPTION_IF_NULL(dout_abs);
// For input, if it is a sparsetensor, we need return a sparsetensor.
@ -382,7 +393,7 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std:
TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
auto parameter = ad_param()->tape_->add_parameter();
parameter->set_abstract(abs_list[i]);
auto zeros_like_dout = BuildZerosLikeNode(ad_param()->tape_, input_param_values[i]);
auto zeros_like_dout = ZerosLike(ad_param()->tape_, abs_list[i]);
auto func_node = std::make_shared<FunctionNode>(ad_param()->tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]);
(void)ad_param()->anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
@ -401,16 +412,19 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
MS_LOG(DEBUG) << "Prim " << prim->name() << " not need do op grad";
return true;
}
bool is_custom_prim =
IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
// anfnode_to_variable_adjoint_ hold out value, to avoid device not release, clear its device_address
auto cloned_value = ShallowCopyTensorValue(grad_param->out);
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(ad_param()->tape_, cloned_value, SpecialType::kZerosLikeType);
// construct zeroslike placeholder, if need use in bprop, we replace it in backprogate.
AnfNodePtr dout = ZerosLike(ad_param()->tape_, grad_param->out->ToAbstract());
auto fn = std::make_shared<FunctionNode>(ad_param()->tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
// Custom forward cnode no need record in bprop graph, because it is a flag cnode for run python. So just create
// bprop_cut grad op is ok
bool is_custom_prim =
IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
if (!grad_param->grad_by_value && !is_custom_prim) {
variable_adjoint->set_k_node(BuildKNode(grad_param, true));
need_do_manager_replace_ = true;
@ -550,7 +564,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const GradParamPtr &grad_param, con
// Call by tape_
MS_EXCEPTION_IF_NULL(tape_dout);
*tape_dout = BuildZerosLikeNode(ad_param()->tape_, grad_param->out);
*tape_dout = ZerosLike(ad_param()->tape_, grad_param->cnode->abstract());
(void)bprop_builder_inputs.emplace_back(*tape_dout);
(void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
get_bprop = ad_param()->tape_->NewCNode(bprop_builder_inputs);
@ -667,7 +681,7 @@ void AutoGradCellImpl::GradGraphByExpander(const GradParamPtr &grad_param) {
auto out = pynative::PyNativeAlgo::Common::CreatOutputTensorValueByAbstract(cnode->abstract());
(void)cnode_inputs.emplace_back(k_node);
// Set dout
AnfNodePtr dout = BuildSpecialLikeValue(ad_param()->tape_, out, SpecialType::kZerosLikeType);
AnfNodePtr dout = ZerosLike(ad_param()->tape_, out->ToAbstract());
(void)cnode_inputs.emplace_back(dout);
auto input_node = ad_param()->tape_->NewCNode(cnode_inputs);
input_node->set_abstract(cnode->abstract());
@ -714,7 +728,7 @@ void AutoGradCellImpl::CreateParameterAdjoint(const GradParamPtr &grad_param) {
param->set_default_param(tensor);
}
param->set_abstract(graph_parameters[i]->abstract());
auto zeros_like_dout = BuildZerosLikeNode(ad_param()->tape_, grad_param->op_args[i]);
auto zeros_like_dout = ZerosLike(ad_param()->tape_, graph_parameters[i]->abstract());
auto func_node = std::make_shared<FunctionNode>(ad_param()->tape_, zeros_like_dout);
auto adjoint = std::make_shared<VariableAdjoint>(func_node, grad_param->op_args[i]);
adjoint->set_k_node(param);
@ -1072,7 +1086,7 @@ void AutoGradCellImpl::BuildForwardLastNode() {
void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor) {
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(tensor);
auto zeros_like_dout = BuildZerosLikeNode(ad_param()->tape_, tensor);
auto zeros_like_dout = ZerosLike(ad_param()->tape_, tensor->ToAbstract());
auto func_node = std::make_shared<FunctionNode>(ad_param()->tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, tensor);
(void)ad_param()->anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
@ -1284,6 +1298,10 @@ void AutoGradCellImpl::BackPropagate() {
has_primc = true;
}
const auto &fn = variable->fn();
// If zeroslike not used in funcgraph, we need replace the zeroslike placeholder with real zeroslike value.
if (MS_UNLIKELY(IsZerosLikeNode(fn->accumulate_dout()))) {
fn->set_accumulate_dout(BuildZerosLikeNode(fn->tape(), variable->out_value()));
}
// Replace real dout to fake dout, update replace result to eliminate tuplegetitem
// when accumulate_dout is tuplegetitem
Replace(fn->fake_dout(), fn->accumulate_dout(), true);

View File

@ -23,6 +23,13 @@
#include "pipeline/pynative/base.h"
#include "pipeline/pynative/pynative_execute.h"
#ifndef MS_UNLIKELY
#ifdef _MSC_VER
#define MS_UNLIKELY(x) (x)
#else
#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0)
#endif
#endif
namespace mindspore {
namespace pynative {
class PyNativeExecutor;