forked from mindspore-Ecosystem/mindspore
!46233 Fix bug for PyNative high grad
Merge pull request !46233 from zjun/fix_ms_function
This commit is contained in:
commit 2c538003fa
@@ -33,6 +33,7 @@
 #include "utils/profile.h"
 #include "include/common/utils/primitive_utils.h"
 #include "pipeline/jit/pass.h"

 namespace mindspore {
 namespace ad {
 namespace {
@@ -234,6 +235,7 @@ bool IsZerosLikeNode(const AnfNodePtr &node) {
 }

 FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
+  pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_before_opt.ir", bprop_func_graph);
   pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
   resource->set_func_graph(bprop_func_graph);
   auto manager = resource->manager();
@@ -243,6 +245,34 @@ FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
   pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
   return after_opt_bg;
 }
+
+bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) {
+  if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
+    return false;
+  }
+  auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
+  auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
+  if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
+      !IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
+    return false;
+  }
+  constexpr int kEmptyTupeSize = 1;
+  if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
+    return false;
+  }
+  return true;
+}
+
+AnfNodePtr GenerateEmptyTupleValue() {
+  std::vector<ValuePtr> value_list;
+  auto inputs_value = std::make_shared<ValueTuple>(value_list);
+  auto weights_value = std::make_shared<ValueTuple>(value_list);
+  std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
+  auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
+  auto tuple_value_node = NewValueNode(tuple_value);
+  tuple_value_node->set_abstract(tuple_value->ToAbstract());
+  return tuple_value_node;
+}
 }  // namespace

 AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
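Note on the two helpers added above: IsOutputBothEmpty treats a gradient output as empty when it is a MakeTuple cnode of size 1, because a cnode stores its primitive as input 0, and GenerateEmptyTupleValue builds the constant ((), ()) that stands in for such outputs. A minimal sketch with stand-in types (not the real AnfNode/CNode classes) showing why the size threshold is 1:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-in node type: a cnode keeps its primitive as inputs[0], so
// MakeTuple() with no elements still has size() == 1.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;
  size_t size() const { return inputs.size(); }
};

int main() {
  auto make_tuple_prim = std::make_shared<Node>();
  make_tuple_prim->name = "MakeTuple";
  auto some_grad = std::make_shared<Node>();
  some_grad->name = "dx";

  Node empty_grads{"cnode", {make_tuple_prim}};                 // MakeTuple()   -> size 1
  Node non_empty_grads{"cnode", {make_tuple_prim, some_grad}};  // MakeTuple(dx) -> size 2

  constexpr size_t kEmptyTupleSize = 1;
  std::cout << std::boolalpha << (empty_grads.size() == kEmptyTupleSize) << "\n";      // true
  std::cout << std::boolalpha << (non_empty_grads.size() == kEmptyTupleSize) << "\n";  // false
  return 0;
}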
@@ -302,16 +332,15 @@ void FunctionNode::ReplaceEdges() {
 AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
     : tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
   tape_->debug_info()->set_name("grad_top");
-  MS_LOG(DEBUG) << "Start AutoGradCellImpl: "
-                << "cell_inputs size: " << cell_inputs.size();
+  MS_LOG(DEBUG) << "Start AutoGradCellImpl, cell_inputs size: " << cell_inputs.size();
   for (size_t i = 0; i < cell_inputs.size(); ++i) {
     TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
     auto parameter = tape_->add_parameter();
     parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden());
     auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]);
     auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
-    auto input_adjoint = std::make_shared<VariableNode>(func_node, input_param_values[i]);
-    anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
+    auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]);
+    (void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
   }
 }

@@ -332,13 +361,14 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
   ClearDeviceAddress(cloned_value);
   AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
   auto fn = std::make_shared<FunctionNode>(tape_, dout);
-  auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
+  auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
   if (!grad_param->grad_by_value) {
     BuildKNode(grad_param, variable_adjoint);
     need_do_manager_replace_ = true;
   }
-  CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout);
+  CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint);
   MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
+  // Gradient outputs
   std::vector<CNodePtr> outputs;
 #ifndef ENABLE_TEST
   if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
@@ -364,7 +394,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
     variable_adjoint->set_is_fake_bprop(true);
     variable_adjoint->set_fake_prim_name(prim->name());
   }
-  anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
+  (void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
   // record last_node for brackpropagate
   last_node_ = grad_param->cnode;
   return true;
@@ -372,9 +402,9 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {

 bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
   MS_EXCEPTION_IF_NULL(grad_param);

   AnfNodePtrList args_node_list;
   CNodePtr bprop_cnode = nullptr;
-  AnfNodePtr k_node = nullptr;
   AnfNodePtr dout = nullptr;
   if (grad_param->grad_by_value) {
     for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
@@ -395,7 +425,11 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
     BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
     bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
   }
+  auto fn = std::make_shared<FunctionNode>(tape_, dout);
+  auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, grad_param->out);
+  if (!grad_param->grad_by_value) {
+    BuildKNode(grad_param, variable_adjoint);
+  }
   std::vector<CNodePtr> outputs;
   for (size_t i = 1; i < grad_param->cnode->size(); ++i) {
     // bprop_app[0] env
@@ -403,11 +437,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
     din->set_abstract(grad_param->op_args[i - 1]->ToAbstract()->Broaden());
     (void)outputs.emplace_back(din);
   }
-  auto fn = std::make_shared<FunctionNode>(tape_, dout);
-  auto variable_adjoint = std::make_shared<VariableNode>(fn, grad_param->out);
-  variable_adjoint->set_k_node(k_node);
   UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
-  anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
+  (void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
   need_do_manager_replace_ = true;
   return true;
 }
@@ -430,7 +461,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
   auto get_bprop =
     bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))});

-  // Get graph after optimize
+  // Get bprop from fprop_fg, it is 2th output of fprop_fg
   AnfNodePtrList node_list{get_bprop};
   auto dout = bprop_builder->add_parameter();
   MS_EXCEPTION_IF_NULL(out);
@@ -438,12 +469,15 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
   (void)node_list.emplace_back(dout);
   auto call_bprop = bprop_builder->NewCNode(node_list);
   bprop_builder->set_output(call_bprop);

+  // Call pass for optimize graph, such as inline
   auto after_opt_fg = OptimizeBpropBuilder(bprop_builder);

   // Call by tape_
   MS_EXCEPTION_IF_NULL(tape_dout);
   *tape_dout = BuildZerosLikeNode(tape_, out);
   (void)bprop_builder_inputs.emplace_back(*tape_dout);
-  bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
+  (void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
   get_bprop = tape_->NewCNode(bprop_builder_inputs);
   // tape_dout is set by next op
   AddUser(*tape_dout, get_bprop, bprop_builder_inputs.size() - 1);
@@ -467,45 +501,46 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
   if (!last_node_->isa<ValueNode>() && !last_node_->isa<Parameter>()) {
     (void)BackPropagate();
   }
-  // Return the gradient;
-  if (grad_attr.get_by_position && grad_position.empty()) {
-    MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
-  }
   SetOutput(weights, grad_position, grad_attr);
   // Replace Parameter of primal funcgraph with parameter of tape_;
   ReplacePrimalParameter(weights, grad_attr.has_sens);
   pynative::PyNativeAlgo::Common::DumpGraphIR("before_final_opt.ir", tape_);
   return tape_;
 }

-CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) {
+CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
+                                                    const VariableAdjointPtr &variable_adjoint) {
   MS_EXCEPTION_IF_NULL(grad_param);
   std::vector<AnfNodePtr> node_list;
   (void)node_list.emplace_back(grad_param->cnode->input(0));
+  auto out_abs = grad_param->out->ToAbstract()->Broaden();
   if (grad_param->grad_by_value) {
     for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
       const auto &v = grad_param->op_args[i];
       auto node = grad_param->cnode->input(i + 1);
       if (node->isa<Parameter>()) {
-        node_list.emplace_back(node);
-        node->set_abstract(v->ToAbstract());
+        (void)node_list.emplace_back(node);
+        node->set_abstract(v->ToAbstract()->Broaden());
         continue;
       }
       auto v_node = NewValueNode(grad_param->op_args[i]);
-      v_node->set_abstract(grad_param->op_args[i]->ToAbstract());
-      node_list.emplace_back(v_node);
+      v_node->set_abstract(grad_param->op_args[i]->ToAbstract()->Broaden());
+      (void)node_list.emplace_back(v_node);
     }
+    // Set out
+    auto out_node = NewValueNode(grad_param->out);
+    out_node->set_abstract(out_abs);
+    (void)node_list.emplace_back(out_node);
   } else {
     // Input is a Parameter or cnode, not a value node
     BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list);
+    // Set out
+    MS_EXCEPTION_IF_NULL(variable_adjoint);
+    (void)node_list.emplace_back(variable_adjoint->k_node());
   }
-  auto out_node = NewValueNode(grad_param->out);
-  auto out_abs = grad_param->out->ToAbstract()->Broaden();
-  out_node->set_abstract(out_abs);
-  // set out
-  node_list.emplace_back(out_node);
-  // set dout
-  node_list.emplace_back(dout);
+  // Set dout
+  (void)node_list.emplace_back(dout);
   auto input_node = tape_->NewCNode(node_list);
   input_node->set_abstract(out_abs);
   return input_node;
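The reworked ConstructBpropGraphInput above assembles the expander's input cnode in a fixed order: the primitive, the forwarded inputs (value nodes under grad_by_value, k-nodes otherwise), the forward output (a value node or the adjoint's k-node), and finally dout. A small sketch with string stand-ins for the node pointers, assuming that ordering:

#include <iostream>
#include <string>
#include <vector>

// Returns the order of the bprop input cnode: [prim, x1, ..., xn, out, dout].
std::vector<std::string> ConstructBpropInput(const std::vector<std::string> &args, bool grad_by_value) {
  std::vector<std::string> node_list{"prim"};
  for (const auto &a : args) {
    // by value: constants captured from the forward run; otherwise: k-nodes of the inputs
    node_list.push_back(grad_by_value ? "value(" + a + ")" : "knode(" + a + ")");
  }
  node_list.push_back(grad_by_value ? "value(out)" : "knode(out)");  // forward output
  node_list.push_back("dout");                                       // incoming gradient
  return node_list;
}

int main() {
  for (const auto &n : ConstructBpropInput({"x", "y"}, true)) {
    std::cout << n << " ";
  }
  std::cout << "\n";  // prim value(x) value(y) value(out) dout
  return 0;
}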
@@ -515,7 +550,7 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
                                                      std::vector<AnfNodePtr> *const node_list) {
   MS_EXCEPTION_IF_NULL(cnode);
   for (size_t i = 1; i < cnode->inputs().size(); ++i) {
-    MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString();
+    MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
     if (cnode->input(i)->isa<CNode>()) {
       const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
       if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
@@ -524,13 +559,13 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
       MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
       (void)node_list->emplace_back(input_adjoint_iter->second->k_node());
     } else {
-      cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract());
+      cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()->Broaden());
       (void)node_list->emplace_back(cnode->input(i));
     }
   }
 }

-void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) {
+void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint) {
   MS_EXCEPTION_IF_NULL(grad_param);
   AnfNodePtrList node_list;
   for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
@@ -538,7 +573,8 @@ void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const Variable
   }
   auto k_node = tape_->NewCNode(node_list);
   k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
-  VariableNode->set_k_node(k_node);
+  variable_adjoint->set_k_node(k_node);
+  MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
 }

 AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
@@ -565,8 +601,9 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr
     MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
   }
   for (size_t i = 0; i < op_args.size(); ++i) {
-    auto node = cnode->input(i + 1);
-    auto din = dins[i];
+    const auto &node = cnode->input(i + 1);
+    const auto &din = dins[i];
+    MS_LOG(DEBUG) << "Node " << node->DebugString() << ", din " << din->DebugString();
     UpdateNextEdges(fn, node, din, op_args[i]);
   }
 }
@@ -617,29 +654,27 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const AnfNodeP
     AddParameterNode(param, tensor);
     UpdateNextEdges(fn, node, din, op_arg);
   } else {
-    MS_LOG(DEBUG) << "It is not a cnode: " << node->DebugString();
+    MS_LOG(DEBUG) << "It is not a cnode or parameter: " << node->DebugString();
     return;
   }
 }

 void AutoGradCellImpl::BuildForwardLastNode() {
+  MS_EXCEPTION_IF_NULL(last_node_);
   if (last_node_->isa<ValueNode>() ||
       anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
     return;
   }
-  if (anfnode_to_variable_adjoint_.find(last_node_) == anfnode_to_variable_adjoint_.end()) {
+  MS_LOG(DEBUG) << "Process last node info " << last_node_->DebugString();
   auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
   auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
   // If last_node is a maketuple or tuplegetitem, need update next edges,
   // if last_node is parameter, not need to update next edges.
   if (last_node_->isa<CNode>()) {
     UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
-  }
-  auto input_adjoint = std::make_shared<VariableNode>(fn, sens_value_);
-  anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
-  } else {
-    MS_LOG(EXCEPTION) << "Unprocessed node: " << last_node_->DebugString();
   }
+  auto input_adjoint = std::make_shared<VariableAdjoint>(fn, sens_value_);
+  (void)anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
 }

 void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor) {
@@ -647,9 +682,9 @@ void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const Value
   MS_EXCEPTION_IF_NULL(tensor);
   auto zeros_like_dout = BuildZerosLikeNode(tape_, tensor);
   auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
-  auto input_adjoint = std::make_shared<VariableNode>(func_node, tensor);
-  anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
-  weights_.push_back(parameter);
+  auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, tensor);
+  (void)anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
+  (void)weights_used_in_graph_.emplace_back(parameter);
 }

 AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
@@ -657,8 +692,8 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
   MS_EXCEPTION_IF_NULL(out_value);
   MS_EXCEPTION_IF_NULL(sub_value);
   MS_EXCEPTION_IF_NULL(din);
-  std::string out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
-  std::string sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
+  const auto &out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
+  const auto &sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
   if (out_value_id == sub_value_id) {
     return din;
   } else if (out_value->isa<tensor::Tensor>()) {
@@ -672,13 +707,13 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
   }
   auto value_seq = out_value->cast<ValueSequencePtr>();
   int index = -1;
-  for (auto value : value_seq->value()) {
+  for (const auto &value : value_seq->value()) {
     auto real_din = GetRealDin(fn, value, sub_value, din);
     (void)inputs.emplace_back(real_din);

     // if exist din == fake_dout, we record it in user vector
     if (din == fn->fake_dout() && real_din == din) {
-      index = inputs.size() - 1;
+      index = static_cast<int>(inputs.size()) - 1;
     }
   }
   auto new_din = tape_->NewCNode(inputs);
@@ -704,7 +739,7 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
   prim_py->AddBpropCutPrim(bprop_cut);
   if (prim->HasAttr("cell_id")) {
     auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
-    if (cell_id != "") {
+    if (!cell_id.empty()) {
       (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
       (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
     }
@@ -723,12 +758,11 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
     if (i < args_size) {
       auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(SizeToLong(i - 1))});
       din->set_abstract(cnode->input(i)->abstract()->Broaden());
-      outputs->emplace_back(din);
+      (void)outputs->emplace_back(din);
       (void)abs.emplace_back(din->abstract());
     }
   }
   output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
-  return;
 }

 void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
@@ -760,11 +794,7 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
 }

 void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
-  MS_EXCEPTION_IF_NULL(last_node_);
-  MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
-
   BuildForwardLastNode();

   // Add sens parameter
   ParameterPtr sens_param = nullptr;
   if (has_sens_arg) {
@@ -774,8 +804,9 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
   }

   // update dout for dout
+  MS_EXCEPTION_IF_NULL(last_node_);
   if (anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
-    auto variable = anfnode_to_variable_adjoint_.at(last_node_);
+    const auto &variable = anfnode_to_variable_adjoint_.at(last_node_);
     if (has_sens_arg && sens_param != nullptr) {
       variable->fn()->UpdateAccumulativeDout(sens_param);
     } else {
@@ -787,19 +818,19 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
   need_grad_weights_.clear();
   for (const auto &weight : weights) {
     TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
-    auto p = tape_->add_parameter();
-    (void)need_grad_weights_.emplace(weight);
-    auto input_w = weight->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(input_w);
     // Use name to match weight parameter in high order
-    auto default_param = input_w->default_param();
-    p->set_name(input_w->name());
-    p->set_default_param(default_param);
-    p->set_abstract(default_param->ToAbstract()->Broaden());
+    auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
+    (void)need_grad_weights_.emplace(t->id());
+    auto p = tape_->add_parameter();
+    auto param = weight->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param);
+    p->set_name(param->name());
+    p->set_default_param(t);
+    p->set_abstract(t->ToAbstract()->Broaden());
   }
 }

-OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
+OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
   for (auto iter = anfnode_to_variable_adjoint_.rbegin(); iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
     if (!iter->first->isa<CNode>()) {
       continue;
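SetSensAndWeights now records weights by the tensor id obtained from GetTensorFromParam rather than by the Parameter node itself, so the same weight can be recognized again when the parameter node is rebuilt for a higher-order graph. A self-contained sketch of id-based matching (simplified types, not the real Parameter/Tensor classes):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Weight {
  std::string id;    // stable tensor id, stands in for Tensor::id()
  std::string name;  // parameter name
};

int main() {
  std::unordered_set<std::string> need_grad_weights;
  std::vector<Weight> requested = {{"t1", "w1"}, {"t2", "w2"}};
  for (const auto &w : requested) {
    need_grad_weights.emplace(w.id);  // key by id, not by node pointer
  }
  // A parameter object rebuilt in a higher-order graph: different node, same id.
  Weight rebuilt{"t1", "w1"};
  std::cout << std::boolalpha
            << (need_grad_weights.find(rebuilt.id) != need_grad_weights.end()) << "\n";  // true
  return 0;
}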
@@ -815,25 +846,30 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL

 void AutoGradCellImpl::BackPropagate() {
   const auto &last_node_reverse_iter = GetLastNodeReverseIter();
+  bool has_primc = false;
   for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
     MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
-    auto variable = iter->second;
+    const auto &variable = iter->second;
     if (!variable->is_need_propagate()) {
+      MS_LOG(DEBUG) << "No need grad";
       continue;
     }
-    if (variable->is_need_propagate() && variable->is_fake_bprop()) {
+    if (variable->is_fake_bprop()) {
       MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has not corresponding bprop!";
     }
-    auto fn = variable->fn();
+    if (!has_primc && iter->first->isa<CNode>() && GetCNodePrimitive(iter->first) != nullptr) {
+      has_primc = true;
+    }
+    const auto &fn = variable->fn();
     // replace real dout to fake dout
     Replace(fn->fake_dout(), fn->RealDout());
     // replace edges which exist fake dout
     fn->ReplaceEdges();

-    auto &next_edges = fn->next_edges();
+    const auto &next_edges = fn->next_edges();
     for (const auto &next_edge : next_edges) {
-      auto node = next_edge.first;
-      auto din = next_edge.second;
+      const auto &node = next_edge.first;
+      const auto &din = next_edge.second;
       if (anfnode_to_variable_adjoint_.find(node) == anfnode_to_variable_adjoint_.end()) {
         MS_LOG(EXCEPTION) << "current node not find corresponding node";
       }
@@ -842,6 +878,7 @@ void AutoGradCellImpl::BackPropagate() {
       last_variable->set_is_need_propagate(true);
     }
   }
+  tape_->set_flag(kPrimCPrimPyMixed, has_primc && need_do_manager_replace_);
 }

 AnfNodePtr AutoGradCellImpl::GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index) {
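BackPropagate now tracks whether any visited node carried a compiled C++ primitive and, combined with the pending manager replace, stores that as the kPrimCPrimPyMixed flag on the tape; a later hunk in this commit makes BpropGraphFinalOpt convert PrimitiveC to PrimitivePy only when renormalizing a graph with this flag. A toy sketch of the flag handshake, with a stand-in Graph type instead of FuncGraph:

#include <iostream>
#include <string>
#include <unordered_map>

constexpr auto kPrimCPrimPyMixed = "primc_primpy_mixed";

// Minimal stand-in for FuncGraph's flag storage.
struct Graph {
  std::unordered_map<std::string, bool> flags;
  void set_flag(const std::string &k, bool v) { flags[k] = v; }
  bool has_flag(const std::string &k) const {
    auto it = flags.find(k);
    return it != flags.end() && it->second;
  }
};

int main() {
  Graph tape;
  bool has_primc = true;                // some backpropagated node used a C++ primitive
  bool need_do_manager_replace = true;  // ms_function / high-order path
  tape.set_flag(kPrimCPrimPyMixed, has_primc && need_do_manager_replace);

  bool need_renormalize = true;
  if (need_renormalize && tape.has_flag(kPrimCPrimPyMixed)) {
    std::cout << "convert PrimitiveC to PrimitivePy before renormalize\n";
  }
  return 0;
}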
@@ -925,34 +962,6 @@ AnfNodePtr AutoGradCellImpl::GetWeightGrad(bool grad_weights, const AnfNodePtrLi
   }
 }
-
-bool AutoGradCellImpl::IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const {
-  if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
-    return false;
-  }
-  auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
-  auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
-  if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
-      !IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
-    return false;
-  }
-  constexpr int kEmptyTupeSize = 1;
-  if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
-    return false;
-  }
-  return true;
-}
-
-AnfNodePtr AutoGradCellImpl::GenerateEmptyTupleValue() {
-  std::vector<ValuePtr> value_list;
-  auto inputs_value = std::make_shared<ValueTuple>(value_list);
-  auto weights_value = std::make_shared<ValueTuple>(value_list);
-  std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
-  auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
-  auto tuple_value_node = NewValueNode(tuple_value);
-  tuple_value_node->set_abstract(tuple_value->ToAbstract());
-  return tuple_value_node;
-}

 void AutoGradCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
                                  const GradAttr &grad_attr) {
   auto inputs_grad_ret = GetInputGrad(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
@@ -1017,8 +1026,8 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
 }

 void AutoGradCellImpl::ElimateTupleGetItem() {
-  for (auto iter = users_.begin(); iter != users_.end(); iter++) {
-    auto old_node = iter->first;
+  for (auto &user : users_) {
+    auto old_node = user.first;
     if (!old_node->isa<CNode>()) {
       continue;
     }
@@ -1039,6 +1048,7 @@ void AutoGradCellImpl::ElimateTupleGetItem() {
 void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
   const auto &parameters = tape_->parameters();
   auto cell_inputs_size = cell_inputs_.size();
+  pynative::PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", tape_);
   if (need_do_manager_replace_) {
     MS_LOG(DEBUG) << "Do parameter replace by manager";
     auto mng = MakeManager({tape_}, false);
@@ -1070,13 +1080,12 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
     }
   }

-  for (auto &weight : weights_) {
-    if (need_grad_weights_.find(weight) == need_grad_weights_.end()) {
-      auto parameter = weight->cast<ParameterPtr>();
-      const auto &input_value = parameter->default_param();
-      MS_EXCEPTION_IF_NULL(input_value);
-      auto value_node = NewValueNode(input_value);
-      value_node->set_abstract(input_value->ToAbstract()->Broaden());
+  for (auto &weight : weights_used_in_graph_) {
+    auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
+    if (need_grad_weights_.find(t->id()) == need_grad_weights_.end()) {
+      MS_LOG(DEBUG) << "Convert " << weight->DebugString() << " to value node";
+      auto value_node = NewValueNode(t);
+      value_node->set_abstract(t->ToAbstract()->Broaden());
       Replace(weight, value_node);
     }
   }
@@ -86,9 +86,9 @@ class FunctionNode {
 };
 using FunctionNodePtr = std::shared_ptr<FunctionNode>;

-class VariableNode {
+class VariableAdjoint {
  public:
-  VariableNode(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
+  VariableAdjoint(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}

   ValuePtr out_value() const { return out_value_; }
   FunctionNodePtr fn() const { return fn_; }
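For orientation, the renamed VariableAdjoint couples one primal node with its FunctionNode (the accumulated dout and next edges), the forward out_value, an optional k-node, and the propagate/fake-bprop flags used during BackPropagate. A compilable sketch that only restates the accessors visible in this diff, with stub types in place of the real IR classes:

#include <memory>
#include <string>

struct FunctionNodeStub {};  // accumulates dout and next edges in the real class
using FunctionNodeStubPtr = std::shared_ptr<FunctionNodeStub>;
struct ValueStub {};
using ValueStubPtr = std::shared_ptr<ValueStub>;
struct NodeStub {};
using NodeStubPtr = std::shared_ptr<NodeStub>;

class VariableAdjoint {
 public:
  VariableAdjoint(const FunctionNodeStubPtr &fn, const ValueStubPtr &out_value)
      : fn_(fn), out_value_(out_value) {}
  FunctionNodeStubPtr fn() const { return fn_; }
  ValueStubPtr out_value() const { return out_value_; }
  void set_k_node(const NodeStubPtr &k) { k_node_ = k; }          // K-mapped cnode owned by tape_
  NodeStubPtr k_node() const { return k_node_; }
  void set_is_need_propagate(bool v) { is_need_propagate_ = v; }  // touched during backprop
  bool is_need_propagate() const { return is_need_propagate_; }
  void set_is_fake_bprop(bool v) { is_fake_bprop_ = v; }          // op without a real bprop
  bool is_fake_bprop() const { return is_fake_bprop_; }
  void set_fake_prim_name(const std::string &n) { fake_prim_name_ = n; }
  const std::string &fake_prim_name() const { return fake_prim_name_; }

 private:
  FunctionNodeStubPtr fn_;
  ValueStubPtr out_value_;
  NodeStubPtr k_node_{nullptr};
  bool is_need_propagate_{false};
  bool is_fake_bprop_{false};
  std::string fake_prim_name_;
};

int main() { return 0; }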
@@ -114,7 +114,7 @@ class VariableNode {
   // K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
   AnfNodePtr k_node_{nullptr};
 };
-using VariableNodePtr = std::shared_ptr<VariableNode>;
+using VariableAdjointPtr = std::shared_ptr<VariableAdjoint>;

 class AutoGradCellImpl {
  public:
@@ -143,10 +143,10 @@ class AutoGradCellImpl {
   // Top cell inputs
   AnfNodePtrList cell_inputs_;
   // These weights need to calculate gradient.
-  mindspore::HashSet<AnfNodePtr> need_grad_weights_;
+  mindspore::HashSet<std::string> need_grad_weights_;
   // Bprop dins of each variable or middle out
-  OrderedMap<AnfNodePtr, VariableNodePtr> anfnode_to_variable_adjoint_;
-  AnfNodePtrList weights_;
+  OrderedMap<AnfNodePtr, VariableAdjointPtr> anfnode_to_variable_adjoint_;
+  AnfNodePtrList weights_used_in_graph_;
   // Record cnode's input map for tape_
   UserType users_;
   // Flag for ms_funtcion and high order
@@ -156,7 +156,8 @@ class AutoGradCellImpl {
   std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);

   // construct input as cnode for expander
-  CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout);
+  CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
+                                    const VariableAdjointPtr &variable_adjoint);
   // Back propagate for one node;
   void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
                        const ValuePtrList &op_args);
@@ -176,7 +177,7 @@ class AutoGradCellImpl {
   // Set sens and weights parameter nodes by user input info
   void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
   // get last reverse iterator
-  OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator GetLastNodeReverseIter();
+  OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator GetLastNodeReverseIter();

   void BackPropagate();
   // Set return node according to grad flag
@@ -184,14 +185,12 @@ class AutoGradCellImpl {
   AnfNodePtr GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index);
   AnfNodePtr GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
   AnfNodePtr GetWeightGrad(bool grad_weights, const AnfNodePtrList &weights, bool weight_param_is_tuple);
-  bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const;
-  AnfNodePtr GenerateEmptyTupleValue();
   void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
   void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
   void ElimateTupleGetItem();

   // Fbprop
-  void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
+  void BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint);
   void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
                                      std::vector<AnfNodePtr> *const node_list);
   AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);
@@ -29,40 +29,32 @@ using mindspore::tensor::Tensor;

 namespace mindspore {
 namespace parallel {
-std::string GetOpPythonPath(const OperatorName &op_name) {
-  // almost all ops are defined in two main paths
-  const std::string ops_module = OP_PATH;
-  const std::string inner_ops_module = INNER_OP_PATH;
-  const std::string grad_ops_module = GRAD_OP_PATH;
-  const std::string functional_op_module = FUNCTIONAL_OP_PATH;
-  py::module mod = py::module::import(common::SafeCStr(ops_module));
-  py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
-  py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module));
-  py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
-
-  if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
-    return inner_ops_module;
+const char *GetOpPythonPath(const char *op_name) {
+  static py::module inner_mod = py::module::import(INNER_OP_PATH);
+  if (py::hasattr(inner_mod, op_name)) {
+    return INNER_OP_PATH;
   }
-  if (py::hasattr(mod, common::SafeCStr(op_name))) {
-    return ops_module;
+  static py::module mod = py::module::import(OP_PATH);
+  if (py::hasattr(mod, op_name)) {
+    return OP_PATH;
   }
-  if (py::hasattr(grad_mod, common::SafeCStr(op_name))) {
-    return grad_ops_module;
+  static py::module grad_mod = py::module::import(GRAD_OP_PATH);
+  if (py::hasattr(grad_mod, op_name)) {
+    return GRAD_OP_PATH;
   }
-  if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
-    MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and "
-                      << functional_op_module << " don't have op:" << op_name;
+  static py::module functional_mod = py::module::import(FUNCTIONAL_OP_PATH);
+  if (!py::hasattr(functional_mod, op_name)) {
+    MS_LOG(EXCEPTION) << OP_PATH << " and " << INNER_OP_PATH << " and " << GRAD_OP_PATH << " and " << FUNCTIONAL_OP_PATH
+                      << " don't have op:" << op_name;
   }
-  return functional_op_module;
+  return FUNCTIONAL_OP_PATH;
 }

 ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
-  std::string op_path = GetOpPythonPath(op_name);
-  py::module mod = py::module::import(common::SafeCStr(op_path));
-  if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
-    return nullptr;
-  }
+  const auto op_path = GetOpPythonPath(op_name.c_str());
   std::vector<py::object> arg_list;
   (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
                        [](const Attr &attr) { return ValueToPyData(attr.second); });
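The rewritten GetOpPythonPath above imports each Python operator module once into a function-local static and returns the matching path constant as soon as a module owns the op, and CreateOpInstance no longer re-imports and re-checks the module. A small sketch of the same function-local static caching pattern, with an ordinary container standing in for py::module; the op name and returned strings are placeholders, not the real path constants:

#include <iostream>
#include <set>
#include <string>

// The set stands in for an imported Python module; the lambda body runs once.
const std::set<std::string> &InnerOps() {
  static const std::set<std::string> ops = [] {
    std::cout << "loading inner ops once\n";  // stands in for py::module::import(...)
    return std::set<std::string>{"ExampleInnerOp"};
  }();
  return ops;
}

// Returns which module path owns the op; the strings below are placeholders
// for the real INNER_OP_PATH / OP_PATH constants.
const char *GetOpPath(const std::string &op_name) {
  if (InnerOps().count(op_name) > 0) {
    return "INNER_OP_PATH";
  }
  return "OP_PATH";
}

int main() {
  std::cout << GetOpPath("ExampleInnerOp") << "\n";
  std::cout << GetOpPath("ExampleInnerOp") << "\n";  // no second "loading" line: the static is reused
  return 0;
}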
@@ -31,7 +31,7 @@ namespace mindspore {
 namespace parallel {
 const char USING_HASH_NAME[] = "USING_HASH_NAME";
 // Get the operator's path where the operator has be defined
-std::string GetOpPythonPath(const OperatorName &op_name);
+const char *GetOpPythonPath(const char *op_name);

 // Init python operator Instance
 ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
@@ -911,6 +911,7 @@ constexpr auto kFlagPyNativeRunInGraph = "pynative_run_in_graph";
 constexpr auto kFlagNeedRenormalize = "need_renormalize";
 constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
 constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
+constexpr auto kPrimCPrimPyMixed = "primc_primpy_mixed";
 // TODO(dsj): for ms_function running in graph_mode. should be delete later
 constexpr auto kAttrMSFunction = "ms_function_graph";
@@ -88,12 +88,16 @@ struct InputArgsInfo {
   bool has_custom_bprop;
   size_t input_size;
   std::string obj_id;

   bool has_sens{false};
   bool use_dynamic_shape_process = false;
   PrimitivePyPtr custom_bprp_prim{nullptr};
   ValuePtr out_value{nullptr};
   std::string cell_id;
+  std::string already_run_cell_id;
   std::string input_args_id;
+  // Cell unique id, cell_id + cell_order;
+  std::string obj_order_id;
   size_t custom_bprop_cell_count = 0;
   size_t grad_order = 0;
   std::vector<std::string> input_arg_id_vec;
@@ -20,6 +20,7 @@
 #include "pipeline/pynative/pynative_utils.h"
 #include "pipeline/jit/pipeline.h"
 #include "ir/cell.h"
+#include "ir/func_graph_cloner.h"
 #include "pipeline/jit/parse/data_converter.h"
 #include "pipeline/jit/debug/trace.h"
 #include "backend/common/optimizer/helper.h"
@@ -71,8 +72,8 @@ std::string GetFnInfoByPyObj(const py::object &obj) {
   return (module_name + "_" + fn_name + "_" + filename + "_" + code_lineno);
 }

-InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, bool is_grad_top_cell,
-                                            bool is_high_order_top_cell) {
+InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::args &args, size_t obj_order,
+                                            bool is_grad_top_cell, bool is_high_order_top_cell) {
   bool has_custom_bprop = py::hasattr(obj, parse::CUSTOM_BPROP_NAME);
   std::string obj_id;
   if (!py::isinstance<Cell>(obj) && (is_grad_top_cell || is_high_order_top_cell)) {
@@ -97,7 +98,9 @@ InputArgsInfoPtr ParsePyArgsToInputArgsInfo(const py::object &obj, const py::arg
     pipeline::CheckArgsValid(obj, args);
   }
   input_args_info->cell_id = GetCellId(obj, args, input_args_info);
-  MS_LOG(DEBUG) << "cell_id is " << obj_id << ", is grad top cell " << (is_grad_top_cell || is_high_order_top_cell);
+  input_args_info->obj_order_id = input_args_info->cell_id + '_' + std::to_string(obj_order);
+  MS_LOG(DEBUG) << "Cell_id is " << input_args_info->cell_id << ", is grad top cell "
+                << (is_grad_top_cell || is_high_order_top_cell);
   return input_args_info;
 }
@@ -194,8 +197,8 @@ bool IsParamInfoEqual(const ParamInfoPtr &p1, const ParamInfoPtr &p2) {
   return p1->key() == p2->key();
 }

-bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const std::vector<AnfNodePtr> &new_anf_inputs,
-                          const TopCellInfoPtr &top_cell) {
+bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, size_t node_index,
+                          const std::vector<AnfNodePtr> &new_anf_inputs, const TopCellInfoPtr &top_cell) {
   MS_EXCEPTION_IF_NULL(old_node_info);
   auto old_input_size = old_node_info->input_cnode_info.size() + old_node_info->input_values.size() +
                         old_node_info->input_param_infos.size();
@@ -237,9 +240,9 @@ bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const s

       // Compare cnode edge.
       MS_EXCEPTION_IF_NULL(top_cell);
-      if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash())) {
+      if (old_op_index != top_cell->get_op_index_by_cnode_hash(new_anf_input->hash(), node_index)) {
         MS_LOG(DEBUG) << "The " << i << "th input, op_index is different, old op_index: " << old_op_index
-                      << " new op_index: " << top_cell->get_op_index_by_cnode_hash(new_anf_input->hash());
+                      << " new op_index: " << node_index;
         return true;
       }
     } else {
@@ -267,23 +270,11 @@ bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info, const s
   return false;
 }

-bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
-                                   bool is_ms_function_node, const std::string &graph_phase,
-                                   const TopCellInfoPtr &top_cell) {
+bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, size_t node_index,
+                                   const CNodePtr &new_cnode, const TopCellInfoPtr &top_cell) {
   MS_EXCEPTION_IF_NULL(old_node_info);
-  // 1.Detect ms_function phase
-  if (is_ms_function_node) {
-    if (!old_node_info->is_graph_node || graph_phase != old_node_info->graph_phase) {
-      MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
-                    << " new is_graph_node: " << is_ms_function_node << " old graph_phase "
-                    << old_node_info->graph_phase << " new graph_phase: " << graph_phase;
-      return true;
-    }
-    return false;
-  }
-
   // 2.Detect cnode prim
-  MS_EXCEPTION_IF_NULL(new_cnode);
   auto new_prim = GetCNodePrimitive(new_cnode);
   if (!common::IsEqual(new_prim, old_node_info->prim)) {
     MS_LOG(DEBUG) << "Graph is dynamic, old prim: "
@@ -299,10 +290,18 @@ bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info
   }

   // 4.Detect inputs
-  return IsCnodeInputsDynamic(old_node_info, new_cnode->inputs(), top_cell);
+  return IsCnodeInputsDynamic(old_node_info, node_index, new_cnode->inputs(), top_cell);
 }

-FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph) {
+FuncGraphPtr BpropGraphFinalOpt(const FuncGraphPtr &bprop_graph, bool need_renormalize) {
+  MS_LOG(DEBUG) << "Do bporp graph final opt";
+  MS_EXCEPTION_IF_NULL(bprop_graph);
+  if (need_renormalize && bprop_graph->has_flag(kPrimCPrimPyMixed)) {
+    MS_LOG(DEBUG) << "Convert PrimitiveC to PrimitivePy";
+    if (!opt::ConvertPrimToPrimPy(bprop_graph)) {
+      MS_LOG(EXCEPTION) << "Convert PrimitiveC to PrimitivePy failed";
+    }
+  }
   auto resource = std::make_shared<pipeline::Resource>();
   resource->set_func_graph(bprop_graph);
   auto manager = resource->manager();
@@ -349,6 +348,24 @@ void SetGraphInputArgs(const std::vector<ValuePtr> &input_vec, const pipeline::R
   }
 }
 
+void SetSensValue(const prim::GradOperationPtr &grad, const InputArgsInfoPtr &input_args_info, const py::args &args) {
+  MS_EXCEPTION_IF_NULL(grad);
+  if (!grad->sens_param()) {
+    return;
+  }
+  MS_LOG(DEBUG) << "Get sens param";
+  size_t forward_args_size = args.size() - 1;
+  auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
+  const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);
+  // Sens have already exist, which may be need update
+  MS_EXCEPTION_IF_NULL(input_args_info);
+  if (input_args_info->input_arg_value_vec.size() == args.size()) {
+    input_args_info->input_arg_value_vec.pop_back();
+  }
+  (void)input_args_info->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor));
+  input_args_info->has_sens = true;
+}
+
 AbstractBasePtr GetGradGraphOutputAbstract(const FuncGraphPtr &fg) {
   MS_EXCEPTION_IF_NULL(fg);
   MS_EXCEPTION_IF_NULL(fg->output());
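The new SetSensValue helper centralizes the sens handling that GradNetInner previously did inline: the sens value passed from Python is converted to a tensor and swapped into the cached input list. A standalone sketch of that swap, using plain std::string stand-ins for the real value types (hypothetical data, not MindSpore classes):

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stored forward args; a stale sens entry may already sit at the back.
  std::vector<std::string> input_arg_value_vec = {"x", "y", "old_sens"};
  const size_t args_size = 3;                  // forward args plus the sens arg passed from Python
  const std::string sens_tensor = "new_sens";  // stands in for ConvertOutputValueToTensor(sens_v)

  if (input_arg_value_vec.size() == args_size) {
    input_arg_value_vec.pop_back();  // sens already exists, drop it before updating
  }
  input_arg_value_vec.emplace_back(sens_tensor);  // stands in for ShallowCopyTensorValue(sens_tensor)

  for (const auto &v : input_arg_value_vec) {
    std::cout << v << " ";  // prints: x y new_sens
  }
  std::cout << std::endl;
  return 0;
}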
@@ -362,13 +379,6 @@ ForwardExecutorPtr GradExecutor::forward() const {
   return forward_executor;
 }
 
-std::string GradExecutor::GetCurCellOrder() const {
-  if (cur_cell_id_.empty()) {
-    MS_LOG(EXCEPTION) << "The cur_cell_id_ is empty!";
-  }
-  return cur_cell_id_ + "_" + std::to_string(cell_order_);
-}
-
 TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
   if (high_order_stack_.empty()) {
     MS_LOG(EXCEPTION) << "Stack high_order_stack_ is empty";
@@ -383,7 +393,6 @@ TopCellInfoPtr GradExecutor::PopHighOrderGraphStack() {
 
 void GradExecutor::PushInputArgsInfoStack(const InputArgsInfoPtr &input_args_info) {
   input_args_info_stack_.push(input_args_info);
-  // ++cell_order_;
 }
 
 void GradExecutor::PopInputArgsInfoStack() {
@@ -468,39 +477,45 @@ void GradExecutor::InitResourceAndDfBuilder(const InputArgsInfoPtr &input_args_i
   }
 }
 
-void GradExecutor::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const {
-  top_cell()->set_need_compile_graph(need_compile_graph);
-  top_cell()->set_forward_already_run(forward_already_run);
-}
-
 void GradExecutor::NewGraphInner(const py::object &obj, const py::args &args) {
   const auto input_args_info = GetInputArgsInfo(obj, args);
   PushInputArgsInfoStack(input_args_info);
 
-  if (input_args_info->has_custom_bprop) {
-    custom_bprop_cell_count_ += 1;
-    input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
-  }
-  if (grad_order_ == 0) {
-    IncreaseGradOrder();
-  }
-  input_args_info->grad_order = grad_order_;
   // May be can async here
   NewGraphImpl(input_args_info);
 }
 
 InputArgsInfoPtr GradExecutor::GetInputArgsInfo(const py::object &obj, const py::args &args) {
   auto input_args_info =
-    ParsePyArgsToInputArgsInfo(obj, args, input_args_info_stack_.empty(), is_high_order_top_cell());
+    ParsePyArgsToInputArgsInfo(obj, args, obj_order_++, input_args_info_stack_.empty(), is_high_order_top_cell());
+  if (input_args_info->has_custom_bprop) {
+    custom_bprop_cell_count_ += 1;
+    input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
+  }
+  // CheckAlready run first, grad_order_ will increase 1(highorder scenario)
+  // If NetA.set_grad(), so come here first, CheckAlready run later, so grad_order_ need increase 1
+  if (input_args_info->is_grad_topest_cell || input_args_info->is_high_order_top_cell) {
+    if (grad_order_ == 0) {
+      IncreaseGradOrder();
+    }
+    // Both set grad: NetA.set_grad(); NetB.set_grad();
+    // Run forward: NetA(); NetB();then grad order is 2
+    // Grad(NetA()); Grad(NetB()). NetA grad order now is 2. Forward run again. grad_order is disordered, so need reset.
+    if (input_args_info->is_grad_topest_cell && input_args_info->grad_order > 1) {
+      input_args_info->grad_order--;
+    }
+    input_args_info->already_run_cell_id = GetAlreadyRunCellId(input_args_info->obj_id);
+  }
+  input_args_info->grad_order = grad_order_;
   input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
+  // top_input_args_info_ indicate current running cell info
+  top_input_args_info_ = input_args_info;
   return input_args_info;
 }
 
 void GradExecutor::NewGraphImpl(const InputArgsInfoPtr &input_args_info) {
   MS_EXCEPTION_IF_NULL(input_args_info);
-  ++cell_order_;
   const auto &cell_id = input_args_info->cell_id;
-  cur_cell_id_ = cell_id;
   MS_LOG(DEBUG) << "NewGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
                 << ", input args info ptr " << input_args_info.get();
   // Make top graph and init resource
@@ -514,34 +529,21 @@ void GradExecutor::AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info) {
 }
 
 void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
-  MS_EXCEPTION_IF_NULL(input_args_info);
-  // CheckAlready run first, grad_order_ will increase 1(highorder scenario)
-  // If NetA.set_grad(), so come here first, CheckAlready run later, so grad_order_ need increase 1
-  if (input_args_info->grad_order == 0) {
-    input_args_info->grad_order++;
-  }
-  // Both set grad: NetA.set_grad(); NetB.set_grad();
-  // Run forward: NetA(); NetB();
-  // Grad(NetA()); Grad(NetB()). grad_order_ is disordered, so need reset.
-  if (input_args_info->is_grad_topest_cell && input_args_info->grad_order > 1) {
-    input_args_info->grad_order--;
-  }
-
   auto fg = std::make_shared<FuncGraph>();
   fg->debug_info()->set_name("pynative_forward_graph");
   auto resource = std::make_shared<pipeline::Resource>();
-  const auto &already_run_cell_id = GetAlreadyRunCellId(input_args_info->obj_id);
-  top_cell_ =
-    std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
-                                  input_args_info->obj_id, input_args_info->cell_id, already_run_cell_id, resource, fg);
+  MS_EXCEPTION_IF_NULL(input_args_info);
+  const auto &obj_id_with_grad_order = input_args_info->obj_id + "_" + std::to_string(input_args_info->grad_order);
+  top_cell_ = std::make_shared<TopCellInfo>(input_args_info->is_high_order_top_cell, input_args_info->grad_order,
                                             obj_id_with_grad_order, input_args_info->cell_id,
+                                            input_args_info->already_run_cell_id, resource, fg);
   top_cell_->set_forward_already_run(true);
   top_cell_->set_input_args_id(input_args_info->input_args_id);
   PushHighOrderGraphStack(top_cell_);
   (void)top_cell_list_.emplace_back(top_cell_);
 
-  const auto &cell_id = input_args_info->obj_id.append("_").append(std::to_string(grad_order_));
   is_cell_id_in_dynamic_detect_nodes_map_ =
-    (cell_id_with_dynamic_detect_nodes_.find(cell_id) != cell_id_with_dynamic_detect_nodes_.end());
+    (cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
   MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
 }
 
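The top cell is now keyed by the object id suffixed with the grad order, so first-order and higher-order top cells built from the same cell object land in different dynamic-detect buckets. A small self-contained illustration of the key construction (plain strings, hypothetical ids, not the real TopCellInfo):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical object id for one cell instance.
  const std::string obj_id = "NetA_140355";
  std::unordered_map<std::string, std::vector<std::string>> cell_id_with_dynamic_detect_nodes;

  // The same cell graded at order 1 and order 2 produces two distinct keys.
  for (size_t grad_order = 1; grad_order <= 2; ++grad_order) {
    const std::string obj_id_with_grad_order = obj_id + "_" + std::to_string(grad_order);
    cell_id_with_dynamic_detect_nodes[obj_id_with_grad_order].push_back("node_info");
  }

  std::cout << cell_id_with_dynamic_detect_nodes.size() << std::endl;  // 2: no collision across orders
  return 0;
}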
@@ -578,25 +580,34 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
   }
   const auto input_args_info = input_args_info_stack_.top();
   MS_EXCEPTION_IF_NULL(input_args_info);
-  if (input_args_info->has_custom_bprop) {
-    GetCustomBpropPrim(obj, args, out, input_args_info);
-  }
-  input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
-  input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
+  UpdateInputArgsInfo(input_args_info, obj, out, args);
   PopInputArgsInfoStack();
-  if (input_args_info->is_grad_topest_cell) {
-    set_grad_flag(false);
-  }
   // May be can async here
   EndGraphImpl(input_args_info);
 }
 
+void GradExecutor::UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj,
+                                       const py::object &out, const py::args &args) {
+  MS_EXCEPTION_IF_NULL(input_args_info);
+  if (input_args_info->has_custom_bprop) {
+    GetCustomBpropPrim(obj, args, out, input_args_info);
+  }
+  // Used at master thread, change its at master thread
+  if (input_args_info->is_grad_topest_cell) {
+    grad_flag_ = false;
+    obj_order_ = 0;
+  }
+  input_args_info->out_value = PyNativeAlgo::DataConvert::PyObjToValue(out);
+  // If use_dynamic_shape_process_ update in run op process, here can instantly sensed
+  input_args_info->use_dynamic_shape_process = use_dynamic_shape_process_;
+}
+
 void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
   MS_EXCEPTION_IF_NULL(input_args_info);
-  const auto &cell_id = input_args_info->cell_id;
-  MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << cell_id
+  MS_LOG(DEBUG) << "EndGraphInner start " << input_args_info->input_size << ", cell_id " << input_args_info->cell_id
                 << ", input args info ptr " << input_args_info.get();
-  bool is_top_cell_end = (cell_id == top_cell()->cell_id());
+  bool is_top_cell_end = (input_args_info->cell_id == top_cell()->cell_id());
   if (is_top_cell_end) {
     input_args_info->out_value = ConvertOutputValueToTensor(input_args_info->out_value);
   }
@@ -619,14 +630,12 @@ void GradExecutor::EndGraphImpl(const InputArgsInfoPtr &input_args_info) {
   }
   // Reset grad flag and update output node of the outermost cell
   if (input_args_info->is_grad_topest_cell && is_top_cell_end) {
-    MS_LOG(DEBUG) << "Cur top last cell " << cell_id;
+    MS_LOG(DEBUG) << "Cur top last cell " << input_args_info->cell_id;
     (void)PopHighOrderGraphStack();
     SetForwardLastNodeInfo(input_args_info->out_value, out_id);
     top_cell()->ClearCellHookOp();
-    cell_order_ = 0;
-    // set_grad_flag(false);
   }
-  // Checkout whether need to compile graph when each top cell has ran finished
+  // Checkout whether need to compile graph when each top cell has run finished
   if (is_top_cell_end) {
     // In high grad cases, the output of the internal graph may be a tuple, and node needs to be created in the getobj
     if (!input_args_info->is_grad_topest_cell) {
@@ -731,7 +740,6 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
   if (already_run_top_cell_.find(already_top_cell_id) == already_run_top_cell_.end()) {
     MS_LOG(DEBUG) << "Cell " << already_top_cell_id << " has never been ran, need compile graph";
     already_run_top_cell_[already_top_cell_id] = new_top_cell;
-    pre_top_cell_ = top_cell();
     return;
   }
 
@@ -742,30 +750,27 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
   if (input_args_info->use_dynamic_shape_process) {
     // Function need compile every time.
     MS_LOG(DEBUG) << "The graph is dynamic, need to compile graph again";
-    EraseTopCellFromTopCellList(pre_top_cell);
     {
       py::gil_scoped_acquire acquire;
-      pre_top_cell->Clear();
+      EraseTopCellFromTopCellList(pre_top_cell);
     }
     already_run_top_cell_[already_top_cell_id] = new_top_cell;
-    pre_top_cell_ = nullptr;
   } else {
-    MS_LOG(DEBUG) << "no need to compile graph again";
+    MS_LOG(DEBUG) << "No need to compile graph again";
     pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
     // In high order situations, the internal top cell remains unchanged, but the external top cell has changed. Then
    // the graph info of the internal top cell needs to be updated so that the external top cell can perceive it.
     if (!input_args_info->is_grad_topest_cell) {
       pre_top_cell->SetGraphInfoMap(pre_top_cell->fg(), new_top_cell->graph_info_map().at(new_top_cell->fg()));
     }
-    pre_top_cell_ = pre_top_cell;
     pre_top_cell->set_forward_already_run(true);
   }
 }
 
 void GradExecutor::EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell) {
   MS_EXCEPTION_IF_NULL(top_cell);
-  auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
+  const auto iter = std::find_if(top_cell_list_.begin(), top_cell_list_.end(),
                            [&](const TopCellInfoPtr &elem) { return elem.get() == top_cell.get(); });
   if (iter == top_cell_list_.end()) {
     MS_LOG(WARNING) << "Can not find top cell " << top_cell.get() << " cell id " << top_cell->cell_id()
                     << " from top cell list";
@@ -783,27 +788,14 @@ void GradExecutor::GradNetInner(const prim::GradOperationPtr &grad, const py::ob
   MS_EXCEPTION_IF_NULL(top_input_args_info_);
   MS_LOG(DEBUG) << "GradNetInner start " << args.size() << ", cell_id " << top_input_args_info_->cell_id
                 << ", input args info ptr " << top_input_args_info_.get();
-  MS_EXCEPTION_IF_NULL(grad);
-  if (grad->sens_param()) {
-    MS_LOG(DEBUG) << "Get sens param";
-    size_t forward_args_size = args.size() - 1;
-    auto sens_v = PyNativeAlgo::DataConvert::PyObjToValue(args[forward_args_size]);
-    const auto &sens_tensor = ConvertOutputValueToTensor(sens_v);
-    // Sens have already exist, which may be need update
-    if (top_input_args_info_->input_arg_value_vec.size() == args.size()) {
-      top_input_args_info_->input_arg_value_vec.pop_back();
-    }
-    (void)top_input_args_info_->input_arg_value_vec.emplace_back(ShallowCopyTensorValue(sens_tensor));
-    top_input_args_info_->has_sens = true;
-  }
-  // For async, top can not be change when run SetForwardLastNodeInfo
-  if (pre_top_cell_ != nullptr) {
-    set_top_cell(pre_top_cell_);
-  }
+  SetSensValue(grad, top_input_args_info_, args);
+  // For async, top can not be change when run SetForwardLastNodeInfo; Change top cell after sync
+  set_top_cell(already_run_top_cell_.at(top_cell()->already_run_cell_id()));
   if (!top_cell()->need_compile_graph()) {
     MS_LOG(DEBUG) << "No need compile graph";
     top_cell_list_.pop_back();
-    UpdateTopCellInfo(false, false);
+    top_cell()->UpdateTopCellInfo(false, false, false);
     return;
   }
   MS_LOG(DEBUG) << "Need compile graph";
@@ -829,12 +821,6 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
                                 const std::vector<size_t> &p_args) {
   // Get bprop graph of top cell
   auto bprop_graph = GetBpropGraph(grad_attr, w_args, p_args);
-  MS_EXCEPTION_IF_NULL(bprop_graph);
-  bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
-  bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
-  bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
-  MS_EXCEPTION_IF_NULL(top_input_args_info_);
-  bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
   auto resource = top_cell()->resource();
   MS_EXCEPTION_IF_NULL(resource);
   resource->set_func_graph(bprop_graph);
@@ -848,7 +834,7 @@ void GradExecutor::GetGradGraph(const ad::GradAttr &grad_attr, const std::vector
   (void)TaskEmitAction(resource);
   MS_LOG(DEBUG) << "Start execute action";
   (void)ExecuteAction(resource);
-  UpdateTopCellInfo(false, false);
+  top_cell()->UpdateTopCellInfo(false, false, true);
   resource->Clean();
 }
 
@@ -870,7 +856,7 @@ std::vector<AnfNodePtr> GradExecutor::GetWeightsArgs(const py::object &weights,
       (void)w_args.emplace_back(fn(weights_tuple[i]));
     }
   } else {
-    MS_LOG(DEBUG) << "No parameter tuple get, add weights params by input weight";
+    MS_LOG(DEBUG) << "No parameter tuple get, try get weights params by input weight";
     if (py::isinstance<py::tuple>(weights) || py::isinstance<py::list>(weights)) {
       auto weights_tuple = py::cast<py::tuple>(weights);
       for (size_t i = 0; i < weights_tuple.size(); ++i) {
@@ -894,6 +880,10 @@ std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_pos
     const auto &tuple = grad_position.cast<py::tuple>();
     (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(pos_args),
                          [](const py::handle &elem) { return elem.cast<int64_t>(); });
+    // Return the gradient;
+    if (pos_args.empty()) {
+      MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
+    }
     return pos_args;
   }
   MS_LOG(EXCEPTION) << "Grad position only support tuple when grad_by_position is set True.";
@@ -982,30 +972,33 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const ad::GradAttr &grad_attr, const ve
                                          const vector<size_t> &p_args) {
   MS_EXCEPTION_IF_NULL(top_input_args_info_);
   auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
-  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
   FuncGraphPtr bprop_graph = ad::GradPynativeCellEnd(auto_grad_cell_ptr, w_args, p_args, grad_attr);
-  MS_EXCEPTION_IF_NULL(bprop_graph);
 
   MS_LOG(DEBUG) << "Top graph input params size " << top_input_args_info_->input_arg_value_vec.size();
   std::ostringstream ss;
   ss << "grad{" << top_input_args_info_->input_arg_value_vec.size() << "}";
-  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  MS_EXCEPTION_IF_NULL(bprop_graph);
   bprop_graph->debug_info()->set_name(ss.str());
   UpdateParamAbsByArgs(top_input_args_info_->input_arg_value_vec, bprop_graph, grad_attr.has_sens);
   if (top_cell()->need_do_final_opt()) {
-    bprop_graph = BpropGraphFinalOpt(bprop_graph);
+    bprop_graph = BpropGraphFinalOpt(bprop_graph, need_renormalize_);
+    MS_EXCEPTION_IF_NULL(bprop_graph);
   }
-  if (top_input_args_info_->is_grad_topest_cell) {
-    need_renormalize_ = false;
-  }
+  need_renormalize_ = false;
+  bprop_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  bprop_graph->set_flag(kFlagIsPynativeBpropGraph, true);
+  bool use_dynamic_shape_process = !(forward()->device_target() == kAscendDevice) && use_dynamic_shape_process_;
+  bprop_graph->set_flag(kFlagUseDynamicShapeProcess, use_dynamic_shape_process);
+  bprop_graph->set_attr(kAttrFuncGraphCellId, MakeValue(top_input_args_info_->obj_id));
   return bprop_graph;
 }
 
-void GradExecutor::SetGradOrder(const std::string &cell_id) {
+void GradExecutor::SetGradOrder(const std::string &obj_id) {
   // top_cell_ == nullptr means call by grad first
-  // Args of CheckAlreadyRun may be have sens arg, so cell_id is include top cell id,
-  // If cell_id.find(top_cell_->cell_id()) == std::string::npos, means current cell is not top cell, may be high order
-  if (top_cell_ == nullptr || cell_id.find(top_cell_->c_cell_id()) == std::string::npos) {
+  // top_cell_->obj_id_with_grad_order() include obj_id and grad_order
+  // If top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos, means current cell is not top cell, grad
+  // high order come in
+  if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos) {
     IncreaseGradOrder();
   }
   if (!grad_is_running_) {
@@ -1016,10 +1009,10 @@ void GradExecutor::SetGradOrder(const std::string &cell_id) {
 
 py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
                                          const py::object &grad_hash_id, const py::args &args) {
-  auto cell_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
+  const auto &obj_id = PyNativeAlgo::PyParser::GetIdByPyObj(obj);
 
   // Check current cell grad order and erase it if in current top cell list
-  SetGradOrder(cell_id);
+  SetGradOrder(obj_id);
   // Include weight param size and required grad flag
   std::string grad_hash_id_str;
   if (!py::isinstance<py::none>(grad_hash_id)) {
@@ -1036,7 +1029,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
   // check whether need to run forward process
   bool forward_run = false;
   if (input_args_info_stack_.empty() && top_cell_ != nullptr) {
-    const auto &check_already_run_cell_id = GetAlreadyRunCellId(cell_id);
+    const auto &check_already_run_cell_id = GetAlreadyRunCellId(obj_id);
     auto find_top_cell = GetTopCell(check_already_run_cell_id);
     if (find_top_cell != nullptr) {
       MS_LOG(DEBUG) << "Find already run top cell";
@@ -1048,7 +1041,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con
       }
     }
   }
-  MS_LOG(DEBUG) << "Graph have already ran " << forward_run << " top cell id " << cell_id;
+  MS_LOG(DEBUG) << "Graph have already ran " << forward_run << " top cell id " << obj_id;
   return BaseRefToPyData(forward_run);
 }
 
@@ -1087,12 +1080,20 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
     ClearGradRes();
     return;
   }
-  FuncGraphPtr first_grad_fg = cur_run_bprop_graph;
+  // High grad hit cache
+  if (!top_cell()->vm_compile()) {
+    SwitchTopCell();
+    return;
+  }
+  auto first_grad_fg = cur_run_bprop_graph;
   if (has_custom_bprop) {
     first_grad_fg = curr_g();
     MS_LOG(DEBUG) << "Bprop nested";
   }
   MS_EXCEPTION_IF_NULL(first_grad_fg);
+  // Because ConvertPrimToPrimPy will change first_grad_fg, when hit bprop graph cache
+  // resource->func_graph() will be changed, abstract may be nullptr.
+  first_grad_fg = BasicClone(first_grad_fg);
   std::vector<AnfNodePtr> inputs{NewValueNode(first_grad_fg)};
   ValuePtrList weights_args;
   DoParameterReplace(first_grad_fg, forward_args, &inputs, &weights_args);
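The added BasicClone call protects the cached bprop graph: once a cached FuncGraph is shared with the compile resource, mutating it in place (as a later ConvertPrimToPrimPy pass would) also changes what the cache sees. A toy sketch of that aliasing hazard, using a shared vector instead of a FuncGraph (hypothetical data, only the sharing behavior matters):

#include <iostream>
#include <memory>
#include <vector>

int main() {
  // The "cached graph": both the cache and the current run hold the same object.
  auto cached_graph = std::make_shared<std::vector<int>>(std::vector<int>{1, 2, 3});

  // Without cloning, mutating through the alias also changes the cached object.
  auto reused = cached_graph;
  reused->push_back(4);
  std::cout << cached_graph->size() << std::endl;  // 4: the cache was silently modified

  // Cloning first (the role BasicClone plays for first_grad_fg) keeps the cache intact.
  auto cloned = std::make_shared<std::vector<int>>(*cached_graph);
  cloned->push_back(5);
  std::cout << cached_graph->size() << " " << cloned->size() << std::endl;  // 4 5
  return 0;
}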
@@ -1162,18 +1163,18 @@ void GradExecutor::DoParameterReplace(const FuncGraphPtr &first_grad_fg, const s
 
   // Replace weights param
   MS_EXCEPTION_IF_NULL(weights_args);
-  mindspore::HashSet<std::string> graph_weights_set;
-  // Weight in graph
+  mindspore::HashSet<std::string> inner_graph_used_weights_set;
+  // Weight in inner graph
   const auto &fir_graph_parameters = first_grad_fg->parameters();
   for (const auto &param : fir_graph_parameters) {
     auto weight_tensor = PyNativeAlgo::Common::GetTensorFromParam(param);
     if (weight_tensor != nullptr) {
-      (void)graph_weights_set.emplace(weight_tensor->id());
+      (void)inner_graph_used_weights_set.emplace(weight_tensor->id());
     }
   }
   for (const auto &weight : inner_graph_info->weight_params) {
-    // If weight used in graph, but not need get grad by gradnet, so will not process in outer graph
-    if (graph_weights_set.find(weight.first) == graph_weights_set.end()) {
+    // If weight used in graph, but not need get grad by gradnet, it will be a valuenode, no need replace
+    if (inner_graph_used_weights_set.find(weight.first) == inner_graph_used_weights_set.end()) {
      continue;
    }
     const auto it = outer_graph_info->weight_params.find(weight.first);
@@ -1351,8 +1352,6 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
   auto cnode = curr_g()->NewCNode(inputs);
   MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
   top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
-  CheckGraphDynamic(cnode, top_cell()->op_index());
-  top_cell()->IncreaseOpIndex();
   return cnode;
 }
 
@@ -1505,7 +1504,6 @@ void GradExecutor::AsyncGradPynativeOp(const ad::AutoGradCellImplPtr &auto_grad_
 
 void GradExecutor::AsyncUpdateOutputNodeOfTopCell(const AnfNodePtr &output_node, const ValuePtr &cloned_value) const {
   auto auto_grad_cell_ptr = top_cell()->auto_grad_cell_ptr();
-  MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
   const auto fn = [auto_grad_cell_ptr, output_node, cloned_value]() {
     auto_grad_cell_ptr->UpdateOutputNodeOfTopCell(output_node, cloned_value);
   };
@@ -1722,7 +1720,7 @@ CNodePtr GradExecutor::ConstructForwardGraph(const FrontendOpRunInfoPtr &op_run_
   }
   const auto &cnode = curr_g()->NewCNodeInOrder(inputs);
   if (IsPrimitiveCNode(cnode, prim::kPrimCellBackwardHook)) {
-    top_cell()->RecordCellBackwardHookOp(GetCurCellOrder(), cnode);
+    top_cell()->RecordCellBackwardHookOp(top_input_args_info()->obj_order_id, cnode);
   }
 
   MS_LOG(DEBUG) << "Make CNode for " << op_run_info->base_op_run_info.op_name << ", new cnode is "
@@ -1745,7 +1743,7 @@ void GradExecutor::SetBpropGraphJitLevel(const py::object &obj) const {
   graph_executor->SetJitConfig(jit_config_dict);
 }
 
-void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx,
+void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t node_idx,
                                                         bool is_ms_function_node,
                                                         const std::string &graph_phase) const {
   MS_EXCEPTION_IF_NULL(cnode);
@@ -1760,7 +1758,7 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c
       node_info->input_values[i] = GetValueNode(input_node);
     } else if (input_node->isa<CNode>()) {
       const auto &node_abs = input_node->abstract();
-      auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash());
+      auto op_index = top_cell()->get_op_index_by_cnode_hash(input_node->hash(), node_idx);
       node_info->input_cnode_info[i] = std::make_pair(op_index, node_abs);
     } else {
       if (!input_node->isa<Parameter>()) {
@@ -1778,11 +1776,12 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, c
     node_info->graph_phase = graph_phase;
   }
   top_cell()->set_cnode_hash_with_op_index(cnode->hash(), node_idx);
-  const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
-  (void)cell_id_with_dynamic_detect_nodes_[cell_id].emplace_back(node_info);
+  (void)cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()].emplace_back(node_info);
+  MS_LOG(DEBUG) << "Save node " << cnode->DebugString() << " firstly, node_idx: " << node_idx
+                << ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase;
 }
 
-bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
+bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
                                   const std::string &graph_phase) const {
   MS_EXCEPTION_IF_NULL(cnode);
   if (!is_cell_id_in_dynamic_detect_nodes_map_) {
@@ -1791,15 +1790,28 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx,
     return false;
   }
 
-  const auto &cell_id = top_cell()->c_cell_id() + "_" + std::to_string(top_cell()->grad_order());
-  const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[cell_id];
+  MS_LOG(DEBUG) << "Check node " << cnode->DebugString() << " node_idx: " << node_idx
+                << ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase;
+  const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()];
   if (node_idx >= dynamic_nodes.size()) {
-    MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << " cur node_idx is: " << node_idx
+    MS_LOG(DEBUG) << "Old dynamic_nodes size: " << dynamic_nodes.size() << ", cur node_idx is: " << node_idx
                   << ", graph is dynamic.";
     return true;
   }
 
-  if (IsDynamicDetectNodeInfoChange(dynamic_nodes[node_idx], cnode, is_ms_function_node, graph_phase, top_cell())) {
+  // 1.Detect ms_function phase
+  const DynamicDetectNodeInfoPtr &old_node_info = dynamic_nodes[node_idx];
+  if (is_ms_function_node) {
+    if (!old_node_info->is_graph_node || graph_phase != old_node_info->graph_phase) {
+      MS_LOG(DEBUG) << "Graph is dynamic, old is_graph_node: " << old_node_info->is_graph_node
+                    << " new is_graph_node: " << is_ms_function_node << " old graph_phase "
+                    << old_node_info->graph_phase << " new graph_phase: " << graph_phase;
+      return true;
+    }
+    return false;
+  }
+
+  if (IsDynamicDetectNodeInfoChange(old_node_info, node_idx, cnode, top_cell())) {
     MS_LOG(DEBUG) << "Graph is dynamic, node_idx: " << node_idx
                   << " is different, cnode: " << cnode->fullname_with_scope();
     return true;
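The reworked IsGraphDynamic keeps the same overall idea: the first execution of a top cell records its nodes in order, and later executions compare the node seen at each index against that record, flagging the graph as dynamic when the sequence diverges or grows. A compact sketch of that check, with string op names standing in for the recorded node infos (illustrative only):

#include <iostream>
#include <string>
#include <vector>

// Returns true when the node at node_idx no longer matches the first-run record.
bool IsGraphDynamicSketch(const std::vector<std::string> &recorded_nodes, size_t node_idx,
                          const std::string &new_node) {
  if (node_idx >= recorded_nodes.size()) {
    return true;  // more nodes than the recorded run: structure changed
  }
  // Stands in for IsDynamicDetectNodeInfoChange (prim, abstract and edge comparison).
  return recorded_nodes[node_idx] != new_node;
}

int main() {
  const std::vector<std::string> first_run = {"Conv2D", "ReLU", "MatMul"};
  std::cout << IsGraphDynamicSketch(first_run, 1, "ReLU") << std::endl;     // 0: same op at index 1
  std::cout << IsGraphDynamicSketch(first_run, 1, "Sigmoid") << std::endl;  // 1: op changed, dynamic
  std::cout << IsGraphDynamicSketch(first_run, 3, "Add") << std::endl;      // 1: extra node, dynamic
  return 0;
}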
@@ -1808,7 +1820,7 @@ bool GradExecutor::IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx,
   return false;
 }
 
-void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
+void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
                                      const std::string &graph_phase) const {
   if (use_dynamic_shape_process_) {
     return;
@@ -1816,9 +1828,7 @@ void GradExecutor::CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_i
 
   use_dynamic_shape_process_ = IsGraphDynamic(cnode, node_idx, is_ms_function_node, graph_phase);
   if (use_dynamic_shape_process_) {
-    MS_LOG(DEBUG) << "Cnode: " << cnode->fullname_with_scope() << ", node_idx: " << node_idx
-                  << ", is_ms_function_node: " << is_ms_function_node << ", graph_phase:" << graph_phase
-                  << ", use_dynamic_shape_process_: " << use_dynamic_shape_process_;
+    MS_LOG(DEBUG) << "Set use_dynamic_shape_process_: " << use_dynamic_shape_process_;
     cell_id_with_dynamic_detect_nodes_.clear();
   }
 }
 
@@ -82,6 +82,10 @@ class GradExecutor {
   inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
     use_dynamic_shape_process_ = use_dynamic_shape_process;
   }
+  inline InputArgsInfoPtr top_input_args_info() const {
+    MS_EXCEPTION_IF_NULL(top_input_args_info_);
+    return top_input_args_info_;
+  }
+
   inline bool need_renormalize() const { return need_renormalize_; }
   inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
@@ -110,15 +114,14 @@ class GradExecutor {
                         const std::vector<tensor::TensorPtr> &pre_tensors) const;
   void ClearRes();
   void WorkerJoin() { async_executor_->WorkerJoin(); }
-  void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
+  void CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node = false,
                          const std::string &graph_phase = "") const;
 
  private:
   ForwardExecutorPtr forward() const;
   inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
   inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }
-  std::string GetCurCellOrder() const;
-  void SetGradOrder(const std::string &cell_id);
+  void SetGradOrder(const std::string &obj_id);
   void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
                          const CNodePtr &cnode) const;
   void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;
@@ -154,7 +157,6 @@ class GradExecutor {
   void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const;
   void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info);
   void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info);
-  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const;
 
   // Manage resource when run grad process.
   bool IsBpropGraph(const std::string &cell_id) const;
@@ -163,6 +165,8 @@ class GradExecutor {
   void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
   void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
   void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
+  void UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj, const py::object &out,
+                           const py::args &args);
   void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
   void AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info);
   void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
@@ -188,9 +192,9 @@ class GradExecutor {
   AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
                                     const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
 
-  void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
+  void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
                                             const std::string &graph_phase) const;
-  bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
+  bool IsGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node,
                       const std::string &graph_phase) const;
 
   bool grad_flag_{false};
@@ -200,16 +204,11 @@ class GradExecutor {
   mutable bool use_dynamic_shape_process_{false};
   mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
   int custom_bprop_cell_count_{0};
-  // Used in sub thread
-  size_t cell_order_{0};
-  std::string cur_cell_id_;
-
+  size_t obj_order_{0};
   // If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
   size_t grad_order_{0};
   std::string grad_operation_;
   TopCellInfoPtr top_cell_{nullptr};
-  TopCellInfoPtr pre_top_cell_{nullptr};
   InputArgsInfoPtr top_input_args_info_{nullptr};
   // Records every cell info for share, regardless of whether need construct grad graph
   std::stack<InputArgsInfoPtr> input_args_info_stack_;
 
@@ -230,11 +230,11 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G
     } else {
       top_cell->fg()->add_parameter(param);
       param->debug_info()->set_name(param->name());
+      top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
     }
     (void)new_params.emplace_back(param);
     (void)input_nodes->emplace_back(param);
     (void)op_run_info->input_value.emplace_back(tensor_value);
-    top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
     MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
                   << tensor_value->ToString() << ". Its name is: " << param->name();
   }
 
@@ -90,6 +90,12 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
   ++op_index_;
 }
 
+void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
+  need_compile_graph_ = need_compile_graph;
+  forward_already_run_ = forward_already_run;
+  vm_compile_ = vm_compile;
+}
+
 void TopCellInfo::ClearDeviceMemory() const {
   MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
   auto ms_context = MsContext::GetInstance();
@@ -132,6 +138,7 @@ void TopCellInfo::Clear() {
   is_init_kpynative_ = false;
   need_compile_graph_ = false;
   forward_already_run_ = false;
+  vm_compile_ = false;
   op_index_ = 0;
   resource_ = nullptr;
   fg_ = nullptr;
 
@@ -58,11 +58,11 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
 class TopCellInfo {
  public:
   ~TopCellInfo() = default;
-  TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string c_cell_id, std::string cellid,
+  TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string obj_id_with_grad_order, std::string cellid,
               std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg)
       : is_high_order_top_cell_(is_high_order_top_cell),
         grad_order_(grad_order),
-        c_cell_id_(std::move(c_cell_id)),
+        obj_id_with_grad_order_(std::move(obj_id_with_grad_order)),
         cell_id_(std::move(cellid)),
         already_run_cell_id_(std::move(already_run_cell_id)),
         resource_(std::move(r)),
@@ -81,6 +81,7 @@ class TopCellInfo {
   inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
   inline bool need_compile_graph() const { return need_compile_graph_; }
   inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
+  inline bool vm_compile() const { return vm_compile_; }
   inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
   inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
   inline bool need_do_final_opt() const { return need_do_final_opt_; }
@@ -91,7 +92,7 @@ class TopCellInfo {
   }
   inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
   inline const std::string &cell_id() const { return cell_id_; }
-  inline const std::string &c_cell_id() const { return c_cell_id_; }
+  inline const std::string &obj_id_with_grad_order() const { return obj_id_with_grad_order_; }
   inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
   inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
   inline const std::string &input_args_id() const { return input_args_id_; }
@@ -124,10 +125,11 @@ class TopCellInfo {
   inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
     cnode_hash_with_op_index_[node_hash] = op_index;
   }
-  inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) const {
+  inline size_t get_op_index_by_cnode_hash(const size_t node_hash, const size_t node_idx) const {
     const auto iter = cnode_hash_with_op_index_.find(node_hash);
     if (iter == cnode_hash_with_op_index_.end()) {
-      MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
+      MS_LOG(DEBUG) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
+      return node_idx;
     }
     return iter->second;
   }
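get_op_index_by_cnode_hash now takes the current node index and returns it as a fallback instead of throwing when the hash has not been recorded, so dynamic-structure checking degrades gracefully. A standalone sketch of the new lookup behavior (a plain map in a mock class, not the real TopCellInfo):

#include <cstddef>
#include <iostream>
#include <unordered_map>

class TopCellSketch {
 public:
  void set_cnode_hash_with_op_index(size_t node_hash, size_t op_index) {
    cnode_hash_with_op_index_[node_hash] = op_index;
  }
  // Unknown hashes used to raise an exception; they now fall back to the caller's node index.
  size_t get_op_index_by_cnode_hash(size_t node_hash, size_t node_idx) const {
    const auto iter = cnode_hash_with_op_index_.find(node_hash);
    if (iter == cnode_hash_with_op_index_.end()) {
      return node_idx;
    }
    return iter->second;
  }

 private:
  std::unordered_map<size_t, size_t> cnode_hash_with_op_index_;
};

int main() {
  TopCellSketch cell;
  cell.set_cnode_hash_with_op_index(/*node_hash=*/42, /*op_index=*/3);
  std::cout << cell.get_op_index_by_cnode_hash(42, 7) << std::endl;  // 3: recorded hash
  std::cout << cell.get_op_index_by_cnode_hash(99, 7) << std::endl;  // 7: unknown hash, fallback
  return 0;
}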
@@ -136,6 +138,7 @@ class TopCellInfo {
   void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
   void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
                                 bool need_save_sub_id = true) const;
+  void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile);
   void ClearDeviceMemory() const;
   void Clear();
 
@@ -150,11 +153,13 @@ class TopCellInfo {
   bool is_init_kpynative_{false};
   bool forward_already_run_{false};
   bool need_compile_graph_{false};
-  size_t op_index_{0};
+  bool vm_compile_{false};
+  bool is_run_cell_{false};
   bool is_high_order_top_cell_{false};
   bool need_do_final_opt_{false};
   size_t grad_order_{0};
-  std::string c_cell_id_;
+  size_t op_index_{0};
+  std::string obj_id_with_grad_order_;
   std::string cell_id_;
   std::string already_run_cell_id_;
   std::string input_args_id_;
@@ -29,6 +29,18 @@
 namespace mindspore {
 namespace pynative {
 namespace PyNativeAlgo {
+namespace {
+void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
+  // Clone a new prim
+  MS_EXCEPTION_IF_NULL(op_run_info);
+  op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
+  MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
+  if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
+    op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
+  }
+}
+}  // namespace
+
 std::string Common::GetIdByValue(const ValuePtr &v) {
   MS_EXCEPTION_IF_NULL(v);
   if (v->isa<tensor::Tensor>()) {
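The new ClonePrim helper gathers the copy-on-write step used before a primitive's attributes are mutated: the shared PrimitivePy is replaced with a fresh copy and, if needed, re-attached to its adapter, so other ops still holding the original pointer keep seeing the old attributes. A minimal sketch of the clone-before-mutate idea on a plain shared_ptr (hypothetical Prim type, not the PrimitivePy/adapter API):

#include <iostream>
#include <memory>
#include <string>

// Stand-in for a primitive whose attributes may be changed per call.
struct Prim {
  std::string name;
  int dyn_input_size = 0;
};

// Replace the shared pointer with a private copy before mutating it,
// so other holders of the original object are unaffected.
void CloneBeforeMutate(std::shared_ptr<Prim> *prim) {
  *prim = std::make_shared<Prim>(**prim);
}

int main() {
  auto shared = std::make_shared<Prim>(Prim{"AddN", 2});
  auto mine = shared;
  CloneBeforeMutate(&mine);
  mine->dyn_input_size = 3;  // mutate only the private copy
  std::cout << shared->dyn_input_size << " " << mine->dyn_input_size << "\n";  // prints: 2 3
  return 0;
}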
@@ -517,12 +529,8 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
   bool need_convert_input_to_attr = NeedConvertConstInputToAttr(op_run_info, device_target, &input_to_attr);
   MS_LOG(DEBUG) << "Need convert input to addr " << need_convert_input_to_attr;
   if (need_convert_input_to_attr) {
-    // Clone a new prim
-    op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
-    MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
-    if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
-      op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
-    }
+    // Prim may be changed attr
+    ClonePrim(op_run_info);
   }
   const auto &op_prim = op_run_info->op_prim;
 
@@ -538,10 +546,17 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
     // Mark tensors, common tensor data : 0, weight param: 1, valuenode(float_, int_): 2
     ConvertValueToTensor(op_run_info, input_object, index, op_prim);
     // -1 indicates input_object is not a dynInput
-    if (op_prim->HasAttr(kAttrDynInputSizes) && !input_object->isa<ValueSequence>()) {
-      auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
-      (void)dyn_v.emplace_back(-1);
-      op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
+    if (op_prim->HasAttr(kAttrDynInputSizes)) {
+      if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
+        // Like addn, prim define in python, but number of inputs change, so the value of kAttrDynInputSizes
+        // changed too. In async, do opgrad may be not complete.
+        ClonePrim(op_run_info);
+      }
+      if (!input_object->isa<ValueSequence>()) {
+        auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
+        (void)dyn_v.emplace_back(-1);
+        op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
+      }
     }
   }
   op_prim->EndRecordAddAttr();
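With the reordered branch, the primitive is cloned whenever kAttrDynInputSizes is present and PyNative synchronization is off, because the asynchronous bprop task may still be reading the attribute while the next call rewrites it; only afterwards is the -1 marker appended for a non-sequence input. A minimal sketch of how such a dyn-input-sizes list is assembled, with -1 standing for a plain (non-tuple) input (hypothetical variable names, not the GetValue/MakeValue API):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Each entry records the element count of a dynamic (tuple) input;
  // -1 marks an input that is not a sequence.
  std::vector<int64_t> dyn_input_sizes;
  dyn_input_sizes.emplace_back(3);         // a dynamic input made of 3 tensors
  (void)dyn_input_sizes.emplace_back(-1);  // a plain tensor input
  for (auto v : dyn_input_sizes) {
    std::cout << v << " ";
  }
  std::cout << "\n";  // prints: 3 -1
  return 0;
}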
@@ -229,5 +229,5 @@ def bprop_scalar_not(x, out, dout):
 
 @bprops.register("TensorMove")
 def bprop_tensor_move(x, out, dout):
-    """Backpropagator for primitive `mutable`."""
+    """Backpropagator for primitive `TensorMove`."""
     return (dout,)
@@ -995,6 +995,7 @@ class RandomShuffle(Primitive):
     def __init__(self, seed=0, seed2=0):
         """Initialize RandomShuffle"""
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
+        self.add_prim_attr("side_effect_hidden", True)
         Validator.check_non_negative_int(seed, "seed", self.name)
         Validator.check_non_negative_int(seed2, "seed2", self.name)
 
@@ -213,9 +213,9 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
 /// Description:
 /// Expectation: the python path is right
 TEST_F(TestStepParallel, GetPythonPath1) {
-  OperatorName operator_name = "AllReduce";
+  const char *operator_name = "AllReduce";
   const std::string expect = "mindspore.ops.operations";
-  auto temp = parallel::GetOpPythonPath(operator_name);
+  std::string temp = parallel::GetOpPythonPath(operator_name);
   ASSERT_EQ(temp, expect);
 }
 
@@ -223,9 +223,9 @@ TEST_F(TestStepParallel, GetPythonPath1) {
 /// Description:
 /// Expectation: the python path is right
 TEST_F(TestStepParallel, GetPythonPath2) {
-  OperatorName operator_name = "Add";
+  const char *operator_name = "Add";
   const std::string expect = "mindspore.ops.operations";
-  auto temp = parallel::GetOpPythonPath(operator_name);
+  std::string temp = parallel::GetOpPythonPath(operator_name);
   ASSERT_EQ(temp, expect);
 }
 