forked from mindspore-Ecosystem/mindspore
!46740 Fix pins for PyNative
Merge pull request !46740 from zjun/fix_xxx_alpha
commit 60a89a9fc8
@@ -61,7 +61,6 @@ struct FrontendOpRunInfo {
int mix_type{0};
size_t op_index = 0;
size_t input_size = 0;
size_t custom_bprop_cell_count = 0;
PrimitivePyPtr op_prim{nullptr};
ValuePtr out_value{nullptr};
std::string op_info;
@@ -92,7 +91,7 @@ struct InputArgsInfo {
bool has_sens{false};
bool grad_is_running{false};
bool use_dynamic_shape_process{false};
PrimitivePyPtr custom_bprp_prim{nullptr};
PrimitivePyPtr custom_bprop_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;
std::string already_run_cell_id;
@@ -266,7 +266,6 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
constexpr auto input_size = 2;
const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
cast_run_info->grad_flag = op_run_info->grad_flag;
cast_run_info->custom_bprop_cell_count = op_run_info->custom_bprop_cell_count;
MS_EXCEPTION_IF_NULL(cast_prim_);
cast_run_info->op_prim = cast_prim_;
cast_run_info->base_op_run_info.op_name = prim::kPrimCast->name();
@@ -180,10 +180,8 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
return;
}
// 4. Do op grad and record op info
// If cell have custom bprop, no need do op grad. Otherwise, need do.
// If ms function is complime, real run op must be not record
bool is_custom_bprop = op_run_info->custom_bprop_cell_count <= 0;
if (!is_ms_function_compiling_ && is_custom_bprop) {
// If ms function is compile, op info will not be find in second training step
if (!is_ms_function_compiling_ && grad()->custom_bprop_cell_count() <= 0) {
grad()->ProcessOpGradInfo(op_run_info);
}
}
@@ -195,7 +193,6 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
// Used for async run
op_run_info->grad_flag = grad()->grad_flag();
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
op_run_info->base_op_run_info.use_dynamic_shape_process =
grad()->use_dynamic_shape_process() &&
@@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2021-2022 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType:
const std::vector<PrimitivePtr> kGradBlackList{
prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient, prim::kPrimUpdateState,
prim::kPrimNPUAllocFloatStatus, prim::kPrimNPUGetFloatStatus, prim::kPrimNPUClearFloatStatus};

void ClearDeviceAddress(const ValuePtr &value) {
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(value, &tensors);
@@ -169,7 +170,7 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
special_like_value->set_abstract(value->ToAbstract()->Broaden());
return special_like_value;
} else {
MS_EXCEPTION(TypeError) << "For value" << value->type()->ToString() << "`, the type is not tensor or sequence";
MS_EXCEPTION(TypeError) << "For value" << value->ToString() << ", the type is not tensor or sequence";
}
}
@@ -280,9 +281,10 @@ AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr
}
}

void FunctionNode::AddEdge(const AnfNodePtr &next_node, const AnfNodePtr &din) {
void FunctionNode::AddNextEdge(const AnfNodePtr &next_node, const AnfNodePtr &din) {
MS_EXCEPTION_IF_NULL(next_node);
MS_EXCEPTION_IF_NULL(din);
// next_node and its corresponding din
(void)next_edges_.emplace_back(std::make_pair(next_node, din));
if (din == fake_dout_) {
(void)need_replace_edges_.emplace_back(next_edges_.size() - 1);
@@ -338,6 +340,7 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std:
bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);

MS_LOG(DEBUG) << "Forward cnode: " << grad_param->cnode->DebugString();
auto prim = GetCNodePrimitive(grad_param->cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString();
@@ -346,44 +349,44 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
MS_LOG(DEBUG) << "Prim " << prim->name() << " not need do op grad";
return true;
}
bool is_custom_prim =
IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
// anfnode_to_variable_adjoint_ hold out value, to avoid device not release, clear its device_address
auto cloned_value = ShallowCopyTensorValue(grad_param->out);
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
// Custom forward cnode no need record in bprop graph, because it is a flag cnode for run python. So just create
// bprop_cut grad op is ok
if (!grad_param->grad_by_value && !is_custom_prim) {
variable_adjoint->set_k_node(BuildKNode(grad_param));
need_do_manager_replace_ = true;
}
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint);
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint, is_custom_prim);
MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
// Gradient outputs
std::vector<CNodePtr> outputs;
#ifndef ENABLE_TEST
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
BuildBPropCutCNode(input_node, &outputs);
if (is_custom_prim) {
BuildBPropCutCNode(input_node, prim, &outputs);
} else {
#ifndef ENABLE_TEST
mindspore::BuildBprop(input_node, &outputs, &users_);
if (outputs.empty()) {
MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << grad_param->cnode->DebugString();
BuildCustomBpropCNode(input_node, &outputs);
BuildCustomBpropCNode(input_node, prim, &outputs);
}
}
#else
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
BuildBPropCutCNode(input_node, &outputs);
} else {
BuildCustomBpropCNode(input_node, &outputs);
}
BuildCustomBpropCNode(input_node, prim, &outputs);
#endif
if (!outputs.empty()) {
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
} else {
}
if (outputs.empty()) {
MS_LOG(DEBUG) << "This op has not custom bprop: " << grad_param->cnode->DebugString();
BuildFakeBpropCNode(input_node, &outputs);
variable_adjoint->set_is_fake_bprop(true);
variable_adjoint->set_fake_prim_name(prim->name());
}
UpdateNextEdges(variable_adjoint, grad_param->cnode, outputs, grad_param->op_args);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
// record last_node for brackpropagate
last_node_ = grad_param->cnode;
@@ -395,6 +398,7 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {

AnfNodePtrList args_node_list;
CNodePtr bprop_cnode = nullptr;
AnfNodePtr k_node = nullptr;
AnfNodePtr dout = nullptr;
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
@@ -412,14 +416,13 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
}
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
} else {
k_node = BuildKNode(grad_param);
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, grad_param->out);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
}
variable_adjoint->set_k_node(k_node);
std::vector<CNodePtr> outputs;
for (size_t i = 1; i < grad_param->cnode->size(); ++i) {
// bprop_app[0] env
@@ -427,7 +430,7 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
din->set_abstract(grad_param->op_args[i - 1]->ToAbstract()->Broaden());
(void)outputs.emplace_back(din);
}
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
UpdateNextEdges(variable_adjoint, grad_param->cnode, outputs, grad_param->op_args);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
need_do_manager_replace_ = true;
return true;
@@ -492,7 +495,6 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
if (!last_node_->isa<ValueNode>() && !last_node_->isa<Parameter>()) {
(void)BackPropagate();
}

SetOutput(weights, grad_position, grad_attr);
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, grad_attr.has_sens);
@@ -501,12 +503,12 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
}

CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint) {
const VariableAdjointPtr &variable_adjoint, bool is_custom_prim) {
MS_EXCEPTION_IF_NULL(grad_param);
std::vector<AnfNodePtr> node_list;
(void)node_list.emplace_back(grad_param->cnode->input(0));
auto out_abs = grad_param->out->ToAbstract()->Broaden();
if (grad_param->grad_by_value) {
if (grad_param->grad_by_value || is_custom_prim) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
const auto &v = grad_param->op_args[i];
auto node = grad_param->cnode->input(i + 1);
@@ -541,7 +543,6 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
std::vector<AnfNodePtr> *const node_list) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
if (cnode->input(i)->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
@@ -553,10 +554,11 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()->Broaden());
(void)node_list->emplace_back(cnode->input(i));
}
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
}
}

void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint) {
AnfNodePtr AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList node_list;
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
@@ -564,8 +566,8 @@ void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const Variable
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
variable_adjoint->set_k_node(k_node);
MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
return k_node;
}

AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
@@ -573,6 +575,11 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
if (input_node->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
return BuildKNodeForMakeTuple(input_node);
} else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
return BuildKNodeForTupleGetItem(input_node);
}
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << input_node->DebugString();
}
return input_adjoint_iter->second->k_node();
@@ -581,22 +588,70 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
}
}

AnfNodePtr AutoGradCellImpl::BuildKNodeForMakeTuple(const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
MS_LOG(DEBUG) << "Build knode for MakeTuple " << input_node->DebugString();
const auto &cnode = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
ValuePtrList op_args;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
(void)inputs.emplace_back(BuildKNodeForCNodeInput(cnode->input(i)));
if (cnode->input(i)->isa<CNode>() || cnode->input(i)->isa<Parameter>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
}
(void)op_args.emplace_back(input_adjoint_iter->second->out_value());
} else {
auto value_node = cnode->input(i)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
(void)op_args.emplace_back(value_node->value());
}
}
auto out_value = MakeValue(op_args);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, out_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, out_value);
auto k_node = tape_->NewCNode(inputs);
k_node->set_abstract(input_node->abstract());
variable_adjoint->set_k_node(k_node);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(input_node, variable_adjoint));
return k_node;
}

AnfNodePtr AutoGradCellImpl::BuildKNodeForTupleGetItem(const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
MS_LOG(DEBUG) << "Build knode for TupleGetItem " << input_node->DebugString();
const auto &tuple_item_cnode = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_item_cnode);
const auto &make_tuple_cnode = tuple_item_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple_cnode);
auto index_value = GetValueNode<Int64ImmPtr>(tuple_item_cnode->input(2));
auto real_node = make_tuple_cnode->input(LongToSize(index_value->value()));
return BuildKNodeForCNodeInput(real_node);
}

bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param) {
return k_cell->KPynativeOp(grad_param);
}

void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode,
void AutoGradCellImpl::UpdateNextEdges(const VariableAdjointPtr &variable, const CNodePtr &cnode,
const std::vector<CNodePtr> &dins, const ValuePtrList &op_args) {
MS_EXCEPTION_IF_NULL(cnode);
if (dins.size() != op_args.size()) {
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args, cnode: " << cnode->DebugString();
}
const auto &fn = variable->fn();
for (size_t i = 0; i < op_args.size(); ++i) {
const auto &node = cnode->input(i + 1);
const auto &din = dins[i];
MS_LOG(DEBUG) << "Node " << node->DebugString() << ", din " << din->DebugString();
UpdateNextEdge(fn, node, din, op_args[i]);
}
if (fn->next_edges().empty()) {
variable->set_is_need_grad(false);
}
}

void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePtr &node, const AnfNodePtr &din,
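The rename from AddEdge to AddNextEdge and the switch from passing the bare FunctionNode to the whole VariableAdjoint let the edge-building step also decide whether a variable needs a gradient at all. A minimal standalone sketch of that bookkeeping, using simplified stand-in types rather than the real MindSpore classes:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Stand-in for AnfNodePtr.
struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
};
using NodePtr = std::shared_ptr<Node>;

class FunctionNode {
 public:
  // Same role as FunctionNode::AddNextEdge: remember (next_node, din) pairs.
  void AddNextEdge(const NodePtr &next_node, const NodePtr &din) { next_edges_.emplace_back(next_node, din); }
  const std::vector<std::pair<NodePtr, NodePtr>> &next_edges() const { return next_edges_; }

 private:
  std::vector<std::pair<NodePtr, NodePtr>> next_edges_;
};
using FunctionNodePtr = std::shared_ptr<FunctionNode>;

class VariableAdjoint {
 public:
  explicit VariableAdjoint(FunctionNodePtr fn) : fn_(std::move(fn)) {}
  const FunctionNodePtr &fn() const { return fn_; }
  bool is_need_grad() const { return is_need_grad_; }
  void set_is_need_grad(bool flag) { is_need_grad_ = flag; }

 private:
  FunctionNodePtr fn_;
  bool is_need_grad_{true};
};

// Mirrors the reworked UpdateNextEdges: record one din per op argument and drop
// the variable from the backward pass when no edge was added for it.
void UpdateNextEdges(const std::shared_ptr<VariableAdjoint> &variable,
                     const std::vector<NodePtr> &inputs, const std::vector<NodePtr> &dins) {
  const auto &fn = variable->fn();
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (dins[i] != nullptr) {  // a null din stands for "this input produced no gradient edge"
      fn->AddNextEdge(inputs[i], dins[i]);
    }
  }
  if (fn->next_edges().empty()) {
    variable->set_is_need_grad(false);
  }
}

int main() {
  auto variable = std::make_shared<VariableAdjoint>(std::make_shared<FunctionNode>());
  std::vector<NodePtr> inputs{std::make_shared<Node>("x"), std::make_shared<Node>("y")};
  std::vector<NodePtr> dins{nullptr, nullptr};  // neither input contributes a gradient
  UpdateNextEdges(variable, inputs, dins);
  std::cout << "need grad: " << std::boolalpha << variable->is_need_grad() << "\n";  // prints false
  return 0;
}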
@@ -604,12 +659,16 @@ void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePt
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(din);
MS_EXCEPTION_IF_NULL(op_arg);
if (anfnode_to_variable_adjoint_.find(node) != anfnode_to_variable_adjoint_.end()) {
auto variable = anfnode_to_variable_adjoint_.at(node);
auto real_din = GetRealDin(fn, variable->out_value(), op_arg, din);
fn->AddEdge(node, real_din);
const auto it = anfnode_to_variable_adjoint_.find(node);
if (it != anfnode_to_variable_adjoint_.end()) {
if (!it->second->is_need_grad()) {
return;
}
auto real_din = GetRealDin(fn, it->second->out_value(), op_arg, din);
fn->AddNextEdge(node, real_din);
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
return;
}
@@ -624,8 +683,10 @@ void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePt
CNodePtr new_din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), din, NewValueNode(SizeToLong(i))});
new_din->set_abstract(sub_value->ToAbstract()->Broaden());
if (din == fn->fake_dout()) {
// The new_din's index input is fn->fake_dout()
AddUser(fn->fake_dout(), new_din, 1);
}
// Add next edge to fn
UpdateNextEdge(fn, input_node, new_din, sub_value);
}
} else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
@@ -636,7 +697,7 @@ void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePt
}
UpdateNextEdge(fn, src_node, din, op_arg);
} else {
MS_LOG(EXCEPTION) << "cnode should be tuplegetitem or maketuple " << cnode->DebugString();
MS_LOG(EXCEPTION) << "Cnode should be tuplegetitem or maketuple " << cnode->DebugString();
}
} else if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
@@ -685,11 +746,14 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
MS_EXCEPTION_IF_NULL(din);
const auto &out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
const auto &sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
// The node corresponding output tensor is the same as the currently used tensor
if (out_value_id == sub_value_id) {
return din;
} else if (out_value->isa<tensor::Tensor>()) {
// out_value is be used, may be it is one of multiple output
return BuildZerosLikeNode(tape_, out_value);
} else if (out_value->isa<ValueSequence>()) {
// The corresponding output of node is ValueSequence, but used one of it
std::vector<AnfNodePtr> inputs;
if (out_value->isa<ValueTuple>()) {
(void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
@@ -699,6 +763,8 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
auto value_seq = out_value->cast<ValueSequencePtr>();
int index = -1;
for (const auto &value : value_seq->value()) {
// Find the value's din, if value equal to sub_value, means value be used, is it will get din; Otherwise value's
// din is zero , which set by second branch condition above
auto real_din = GetRealDin(fn, value, sub_value, din);
(void)inputs.emplace_back(real_din);
@@ -717,12 +783,9 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
return nullptr;
}

void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}

void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim,
std::vector<CNodePtr> *outputs) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_py = prim->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(prim_py);
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
@@ -738,29 +801,35 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
if (prim->HasAttr("custom_op_bprop")) {
(void)bprop_cut->AddAttr("custom_op_bprop", MakeValue(true));
}

// Create gradient outputs cnode
std::vector<AnfNodePtr> inputs{NewValueNode(bprop_cut)};
auto output = tape_->NewCNode(inputs);
AbstractBasePtrList abs;
size_t args_size = cnode->size() - 2;
// Get input, get output, get dout
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
(void)inputs.emplace_back(cnode->input(i));
}
auto bprop_cut_cnode = tape_->NewCNode(inputs);

size_t input_num = cnode->size() - 2;
AbstractBasePtrList abs_list;
for (size_t i = 1; i < cnode->size(); ++i) {
output->add_input(cnode->input(i));
AddUser(cnode->input(i), output, i);
if (i < args_size) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(SizeToLong(i - 1))});
din->set_abstract(cnode->input(i)->abstract()->Broaden());
// bprop_cut_cnode ith input used cnode->input(i)
AddUser(cnode->input(i), bprop_cut_cnode, i);
if (i < input_num) {
auto din = tape_->NewCNode(
{NewValueNode(prim::kPrimTupleGetItem), bprop_cut_cnode, NewValueNode(static_cast<int64_t>(i - 1))});
MS_EXCEPTION_IF_NULL(cnode->input(i)->abstract());
din->set_abstract(cnode->input(i)->abstract());
abs_list.emplace_back(cnode->input(i)->abstract());
(void)outputs->emplace_back(din);
(void)abs.emplace_back(din->abstract());
}
}
output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
bprop_cut_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
}

void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim,
std::vector<CNodePtr> *outputs) {
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(DEBUG) << "Build custom bprop: " << prim->name();
auto prim_py = prim->cast<PrimitivePyPtr>();
{
py::gil_scoped_acquire gil;
@@ -780,7 +849,23 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
prim_py->AddBackwardHookFn(0, fn);
prim_py->AddAttr("custom_op_bprop", MakeValue(True));
}
BuildBPropCutCNode(cnode, outputs);
BuildBPropCutCNode(cnode, prim, outputs);
}

void AutoGradCellImpl::BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}
size_t dout_index = cnode->size() - 1;
const auto &dout = cnode->input(dout_index);
const auto &dout_cnode = dout->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dout_cnode);
// Size is same as op_arg size
size_t input_size = cnode->size() - 2;
for (size_t i = 1; i < input_size; ++i) {
(void)outputs->emplace_back(dout_cnode);
}
}

void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
@@ -840,13 +925,12 @@ void AutoGradCellImpl::BackPropagate() {
for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
const auto &variable = iter->second;
if (!variable->is_need_propagate()) {
if (!variable->is_need_propagate() || !variable->is_need_grad()) {
MS_LOG(DEBUG) << "No need grad";
continue;
}
if (variable->is_fake_bprop()) {
MS_LOG(WARNING) << variable->fake_prim_name() << " op has not corresponding bprop!";
continue;
MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has not corresponding bprop!";
}
if (!has_primc && iter->first->isa<CNode>() && GetCNodePrimitive(iter->first) != nullptr) {
has_primc = true;
@@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -64,7 +64,7 @@ class FunctionNode {
public:
FunctionNode(const FuncGraphPtr &tape, const AnfNodePtr &dout)
: tape_(tape), accumulate_dout_(dout), fake_dout_(dout) {}
void AddEdge(const AnfNodePtr &next_node, const AnfNodePtr &din);
void AddNextEdge(const AnfNodePtr &next_node, const AnfNodePtr &din);
void UpdateAccumulativeDout(const AnfNodePtr &new_dout);
const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &next_edges() const { return next_edges_; }
const FuncGraphPtr tape() { return tape_; }
@@ -91,6 +91,7 @@ using FunctionNodePtr = std::shared_ptr<FunctionNode>;
// Variable represent a parameter or output of a middle cnode
class VariableAdjoint {
public:
VariableAdjoint() = default;
VariableAdjoint(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}

ValuePtr out_value() const { return out_value_; }
@@ -101,20 +102,24 @@ class VariableAdjoint {
void set_is_fake_bprop(bool is_fake_bprop) { is_fake_bprop_ = is_fake_bprop; }
bool is_need_propagate() const { return is_need_propagate_; }
void set_is_need_propagate(bool is_need_grad) { is_need_propagate_ = is_need_grad; }
bool is_need_grad() const { return is_need_grad_; }
void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }
AnfNodePtr k_node() const { return k_node_; }
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
AnfNodePtr RealDout();

private:
// Abstract bprop function
FunctionNodePtr fn_;
ValuePtr out_value_;
FunctionNodePtr fn_{nullptr};
ValuePtr out_value_{nullptr};
// If node has not bprop, we record its prim name
std::string fake_prim_name_;
// Record this node is a fake bprop
bool is_fake_bprop_{false};
// Flag to judge need to propagrate
bool is_need_propagate_{false};
// Flag to judge variable whether need grad
bool is_need_grad_{true};
// K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_{nullptr};
};
@@ -161,9 +166,9 @@ class AutoGradCellImpl {

// construct input as cnode for expander
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint);
const VariableAdjointPtr &variable_adjoint, bool is_custom_prim);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
void UpdateNextEdges(const VariableAdjointPtr &variable, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);
void UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePtr &node, const AnfNodePtr &din, const ValuePtr &op_arg);
@@ -172,8 +177,8 @@ class AutoGradCellImpl {
void AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor);
AnfNodePtr GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
const AnfNodePtr &din);
void BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
void BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
void BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs);
void BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs);
void BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
// Replace input or weights parameter from primal funcgraph to parameters of tape_;
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
@@ -195,10 +200,12 @@ class AutoGradCellImpl {
void ElimateTupleGetItem();

// Fbprop
void BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint);
AnfNodePtr BuildKNode(const GradParamPtr &grad_param);
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list);
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);
AnfNodePtr BuildKNodeForMakeTuple(const AnfNodePtr &input_node);
AnfNodePtr BuildKNodeForTupleGetItem(const AnfNodePtr &input_node);
};
using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;
@@ -142,9 +142,8 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(value_tuple);
return opt::CreateTupleTensor(value_tuple);
}
PyNativeAlgo::Common::FilterSensValues(v);
MS_LOG(DEBUG) << "Output is value sequence, but have tensor and other type mixed. Its value is " << v->ToString();
return v;
return PyNativeAlgo::Common::FilterSensValues(v);
} else if (v->isa<FloatImm>()) {
double input_value = v->cast<FP32ImmPtr>()->value();
return std::make_shared<tensor::Tensor>(input_value, kFloat32);
@@ -552,11 +551,10 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
input_args_info->already_run_cell_id, resource, fg);
top_cell_->set_forward_already_run(true);
top_cell_->set_input_args_id(input_args_info->input_args_id);
top_cell_->set_is_cell_id_in_dynamic_detect_nodes_map(
cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
PushHighOrderGraphStack(top_cell_);
(void)top_cell_list_.emplace_back(top_cell_);

is_cell_id_in_dynamic_detect_nodes_map_ =
(cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
}
@@ -603,9 +601,7 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
void GradExecutor::UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj,
const py::object &out, const py::args &args) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (input_args_info->has_custom_bprop) {
GetCustomBpropPrim(obj, args, out, input_args_info);
}
GetCustomBpropPrim(obj, args, out, input_args_info);
// Used at master thread, change its at master thread
if (input_args_info->is_grad_topest_cell) {
grad_flag_ = false;
@@ -678,27 +674,30 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
return;
}
MS_LOG(DEBUG) << "Do grad for custom bprop";
MS_EXCEPTION_IF_NULL(input_args_info->custom_bprp_prim);
MS_EXCEPTION_IF_NULL(input_args_info->custom_bprop_prim);
auto op_run_info = std::make_shared<FrontendOpRunInfo>();
op_run_info->grad_flag = true;
op_run_info->base_op_run_info.op_name = input_args_info->custom_bprp_prim->name();
op_run_info->op_prim = input_args_info->custom_bprp_prim;
op_run_info->base_op_run_info.op_name = input_args_info->custom_bprop_prim->name();
op_run_info->op_prim = input_args_info->custom_bprop_prim;
op_run_info->input_value = input_args_info->input_arg_value_vec;
op_run_info->input_size = input_args_info->input_arg_value_vec.size();
op_run_info->input_value_id = input_args_info->input_arg_id_vec;
auto cnode = ConstructForwardGraph(op_run_info);
if (!input_args_info->grad_is_running || bprop_grad_stack_.top().second) {
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
}
cnode->set_abstract(input_args_info->out_value->ToAbstract()->Broaden());
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
(void)CheckGraphDynamic(cnode);
SaveOutputNodeMap(out_id, op_run_info, cnode);
}

void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (!input_args_info->has_custom_bprop) {
return;
}
custom_bprop_cell_count_ -= 1;
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
if (custom_bprop_cell_count_ != 0) {
input_args_info->custom_bprop_cell_count -= 1;
if (input_args_info->custom_bprop_cell_count != 0) {
return;
}
MS_LOG(DEBUG) << "Get custom bprop prim";
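With this change the custom-bprop counter is decremented on the per-call InputArgsInfo rather than on the GradExecutor-wide member, so nested cells with a custom bprop are counted per top-cell run. A simplified standalone sketch of that counting, with hypothetical stand-in types rather than the real MindSpore classes:

#include <iostream>
#include <string>

// Hypothetical, simplified stand-in for the real InputArgsInfo.
struct InputArgsInfo {
  bool has_custom_bprop{false};
  int custom_bprop_cell_count{0};  // how many enclosing cells declared a custom bprop
  std::string custom_bprop_prim;   // stand-in for the recorded PrimitivePyPtr
};

// Mirrors the reshaped GetCustomBpropPrim: decrement the per-call counter and
// only record the bprop primitive when the outermost custom-bprop cell ends.
void GetCustomBpropPrim(InputArgsInfo *info, const std::string &prim_name) {
  if (!info->has_custom_bprop) {
    return;
  }
  info->custom_bprop_cell_count -= 1;
  if (info->custom_bprop_cell_count != 0) {
    return;  // still inside a nested custom-bprop cell
  }
  info->custom_bprop_prim = prim_name;
}

int main() {
  InputArgsInfo info;
  info.has_custom_bprop = true;
  info.custom_bprop_cell_count = 2;          // two nested cells with custom bprop
  GetCustomBpropPrim(&info, "inner_bprop");  // inner cell ends: nothing recorded yet
  GetCustomBpropPrim(&info, "outer_bprop");  // outermost cell ends: primitive recorded
  std::cout << "recorded prim: " << info.custom_bprop_prim << "\n";  // prints outer_bprop
  return 0;
}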
@@ -744,7 +743,7 @@ void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg
(void)input_args_info->input_arg_value_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i]));
}
}
input_args_info->custom_bprp_prim = fake_prim;
input_args_info->custom_bprop_prim = fake_prim;
}

void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) {
@@ -772,6 +771,7 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
EraseTopCellFromTopCellList(pre_top_cell);
}
already_run_top_cell_[already_top_cell_id] = new_top_cell;
new_top_cell->set_force_top_cell_compile(false);
} else {
MS_LOG(DEBUG) << "No need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
@@ -1042,8 +1042,7 @@ void GradExecutor::SetGradOrder(const std::string &obj_id) {
// top_cell_->obj_id_with_grad_order() include obj_id and grad_order
// If top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos and have cell info stack, means current
// cell is not top cell, grad high order come in
if (top_cell_ == nullptr ||
(!input_args_info_stack_.empty() && top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos)) {
if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos) {
IncreaseGradOrder();
}
if (!grad_is_running_) {
@@ -1107,7 +1106,11 @@ py::object GradExecutor::RunGradGraph() {
MS_LOG(DEBUG) << "Eval run " << backend;
grad_is_running_ = true;
top_cell()->set_auto_grad_cell_ptr(nullptr);
// In custom bprop, when running bprop function, top_input_args_info_ will be changed.
// So, here copy and restore after running finished.
auto top_input_args_info = top_input_args_info_;
BaseRef out_value = (*run)(arg_list);
top_input_args_info_ = top_input_args_info;
grad_is_running_ = false;
MS_LOG(DEBUG) << "Eval run end " << out_value.ToString();
const auto &cur_run_bprop_graph = resource->func_graph();
@@ -1139,13 +1142,28 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu

auto cnode = curr_g()->NewCNode(inputs);
auto out_value = PyNativeAlgo::DataConvert::BaseRefToValue(out);
// Get output values
if (has_custom_bprop && !out_value->isa<ValueSequence>()) {
std::vector<ValuePtr> out_v{out_value};
out_value = std::make_shared<ValueTuple>(out_v);
}
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value);
top_cell()->SetNodeMapInGraphInfoMap(out_id, cnode);
cnode->set_abstract(out_value->ToAbstract()->Broaden());
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();

MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString() << ", out id " << out_id;
// High grad hit cache
bool need_do_grad = true;
if (!cur_vm_compile) {
if (already_run_top_cell_.find(top_cell()->already_run_cell_id()) != already_run_top_cell_.end()) {
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()];
MS_LOG(DEBUG) << "Cur op index " << (top_cell()->op_index() + 1) << ", outer graph all op size "
<< dynamic_nodes.size();
if (top_cell()->op_index() + 1 == dynamic_nodes.size()) {
need_do_grad = false;
}
}
}
if (!use_dynamic_shape_process_ && !need_do_grad) {
return;
}
@@ -1170,11 +1188,7 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
// Get input values
ValuePtrList input_args(forward_args);
(void)input_args.insert(input_args.end(), weights_args.cbegin(), weights_args.cend());
// Get output values
if (has_custom_bprop && !out_value->isa<ValueSequence>()) {
std::vector<ValuePtr> out_v{out_value};
out_value = std::make_shared<ValueTuple>(out_v);
}

auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg,
!top_cell()->is_high_order_top_cell());
if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
@@ -1285,7 +1299,6 @@ void GradExecutor::ClearRes() {
need_renormalize_ = false;
eliminate_forward_ = true;
use_dynamic_shape_process_ = false;
is_cell_id_in_dynamic_detect_nodes_map_ = false;
custom_bprop_cell_count_ = 0;
grad_order_ = 0;
top_cell_ = nullptr;
@@ -1388,6 +1401,7 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
return nullptr;
}
ValuePtrList input_args;
abstract::AbstractBasePtrList abs_list;
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
const auto &obj_tuple = v->cast<ValueSequencePtr>();
const auto &v_list = obj_tuple->value();
@@ -1400,10 +1414,12 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
(void)input_args.emplace_back(v_arg);
const std::string &id = PyNativeAlgo::Common::GetIdByValue(v_arg);
(void)inputs.emplace_back(GetInput(v_arg, id));
(void)abs_list.emplace_back(v_arg->ToAbstract()->Broaden());
(void)GetValueSequenceInput(v_arg, id);
}
// Create make tuple node and record to graph info map.
auto cnode = curr_g()->NewCNode(inputs);
cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
(void)CheckGraphDynamic(cnode);
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
@@ -1412,9 +1428,9 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str

AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const {
MS_LOG(DEBUG) << "Output size: " << out.second.size();
auto c_node = out.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
MS_LOG(DEBUG) << "Cnode " << c_node->DebugString() << ", id " << obj_id << ", out second " << out.second;
auto abs = c_node->abstract();
// Create tuple get item node
for (const auto &idx : out.second) {
@@ -1430,12 +1446,11 @@ AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
}
auto prim_abs = elements[static_cast<size_t>(idx)];
MS_EXCEPTION_IF_NULL(prim_abs);
MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
c_node->set_abstract(prim_abs);
}
}
(void)CheckGraphDynamic(c_node);
MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id;
MS_LOG(DEBUG) << "Create tuple getitem node " << c_node->ToString() << ", abs " << c_node->abstract()->ToString();
return c_node;
}
@@ -1524,7 +1539,10 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOp
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
const ValuePtr &op_out) const {
MS_EXCEPTION_IF_NULL(op_run_info);

if (grad_is_running_ && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
return;
}
// to avoid out exist in tape bprop, avoid out be modified.
ValuePtrList cloned_op_args;
(void)std::transform(op_run_info->input_value.begin(), op_run_info->input_value.end(),
@@ -1826,7 +1844,7 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const AnfNodePtr &anf_no
bool GradExecutor::IsGraphDynamic(const AnfNodePtr &anf_node, const size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(anf_node);
if (!is_cell_id_in_dynamic_detect_nodes_map_) {
if (!top_cell()->is_cell_id_in_dynamic_detect_nodes_map()) {
SaveDynamicDetectNodeInfoInFirstTime(anf_node, node_idx, is_ms_function_node, graph_phase);
// The net is regarded as a static net by default in the first time.
return false;
@@ -200,7 +200,7 @@ class GradExecutor {
bool eliminate_forward_{true};
mutable bool use_dynamic_shape_process_{false};
mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
int custom_bprop_cell_count_{0};
size_t custom_bprop_cell_count_{0};
size_t obj_order_{0};
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
size_t grad_order_{0};
@@ -89,6 +89,10 @@ class TopCellInfo {
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; }
inline void set_is_cell_id_in_dynamic_detect_nodes_map(bool is_cell_id_in_dynamic_detect_nodes_map) {
is_cell_id_in_dynamic_detect_nodes_map_ = is_cell_id_in_dynamic_detect_nodes_map;
}
inline bool is_cell_id_in_dynamic_detect_nodes_map() const { return is_cell_id_in_dynamic_detect_nodes_map_; }
inline pipeline::ResourcePtr resource() const { return resource_; }
inline FuncGraphPtr fg() const {
MS_EXCEPTION_IF_NULL(fg_);
@@ -162,8 +166,9 @@ class TopCellInfo {
bool force_top_cell_compile_{false};
bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false};
size_t op_index_{0};
bool is_cell_id_in_dynamic_detect_nodes_map_{false};
size_t grad_order_{0};
size_t op_index_{0};
std::string obj_id_with_grad_order_;
std::string cell_id_;
std::string already_run_cell_id_;
@@ -190,7 +190,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
q.put({'loss': loss_list, 'cost': per_step_mseconds})


@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single