forked from mindspore-Ecosystem/mindspore
!46504 Fix bug for PyNative
Merge pull request !46504 from zjun/fix_high
This commit is contained in: commit b34dd895f3

@@ -375,6 +375,17 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
MS_EXCEPTION_IF_NULL(make_tuple);
args_tuple_num = make_tuple->inputs().size() - 1;
continue;
} else if (input_node->isa<Parameter>()) {
auto param = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto abs = param->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTuple>() && !abs->isa<abstract::AbstractSparseTensor>()) {
auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs_tuple);
args_tuple_num = abs_tuple->elements().size();
continue;
}
}
}
// Hook single-input or single-output.

@@ -33,6 +33,7 @@
#include "utils/profile.h"
#include "include/common/utils/primitive_utils.h"
#include "pipeline/jit/pass.h"
namespace mindspore {
namespace ad {
namespace {

@@ -43,8 +44,9 @@ constexpr char kAttrOnesLikeCOO[] = "ones_like_coo_node";
enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
{SpecialType::kOnesLikeType, prim::kPrimOnesLike}};
const std::vector<PrimitivePtr> kGradBlackList{prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient,
prim::kPrimUpdateState, prim::kPrimNPUAllocFloatStatus};
const std::vector<PrimitivePtr> kGradBlackList{
prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient, prim::kPrimUpdateState,
prim::kPrimNPUAllocFloatStatus, prim::kPrimNPUGetFloatStatus, prim::kPrimNPUClearFloatStatus};
AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type);
void ClearDeviceAddress(const ValuePtr &value) {
std::vector<tensor::TensorPtr> tensors;

@@ -55,26 +57,6 @@ void ClearDeviceAddress(const ValuePtr &value) {
}
}
ValuePtr FilterSensValues(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
return value;
} else if (value->isa<ValueSequence>()) {
std::vector<ValuePtr> value_list;
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
for (auto filter_value : value_seq->value()) {
if (FilterSensValues(filter_value) != nullptr) {
(void)value_list.emplace_back(filter_value);
}
}
return std::make_shared<ValueTuple>(value_list);
} else {
MS_LOG(DEBUG) << "value type: " << value->ToString();
return nullptr;
}
}
bool IsPrimNeedGrad(const PrimitivePtr &prim) {
for (const auto &no_need_grad_prim : kGradBlackList) {
if (IsPrimitiveEquals(prim, no_need_grad_prim)) {

@@ -234,6 +216,7 @@ bool IsZerosLikeNode(const AnfNodePtr &node) {
}
FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_before_opt.ir", bprop_func_graph);
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_func_graph);
auto manager = resource->manager();

@@ -243,6 +226,34 @@ FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
return after_opt_bg;
}
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) {
if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
return false;
}
auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
!IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
return false;
}
constexpr int kEmptyTupeSize = 1;
if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
return false;
}
return true;
}
AnfNodePtr GenerateEmptyTupleValue() {
std::vector<ValuePtr> value_list;
auto inputs_value = std::make_shared<ValueTuple>(value_list);
auto weights_value = std::make_shared<ValueTuple>(value_list);
std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
auto tuple_value_node = NewValueNode(tuple_value);
tuple_value_node->set_abstract(tuple_value->ToAbstract());
return tuple_value_node;
}
} // namespace
AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {

@@ -302,14 +313,15 @@ void FunctionNode::ReplaceEdges() {
AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
: tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
tape_->debug_info()->set_name("grad_top");
MS_LOG(DEBUG) << "Start AutoGradCellImpl, cell_inputs size: " << cell_inputs.size();
for (size_t i = 0; i < cell_inputs.size(); ++i) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
auto parameter = tape_->add_parameter();
parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden());
auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableNode>(func_node, input_param_values[i]);
anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
}
}

@@ -329,12 +341,14 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
need_do_manager_replace_ = true;
}
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout);
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint);
MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
// Gradient outputs
std::vector<CNodePtr> outputs;
#ifndef ENABLE_TEST
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {

@@ -342,7 +356,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
} else {
mindspore::BuildBprop(input_node, &outputs, &users_);
if (outputs.empty()) {
MS_LOG(DEBUG) << "expander has no bprop of this prim: " << grad_param->cnode->DebugString();
MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << grad_param->cnode->DebugString();
BuildCustomBpropCNode(input_node, &outputs);
}
}

@@ -356,11 +370,11 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
if (!outputs.empty()) {
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
} else {
MS_LOG(DEBUG) << "this op has not custom bprop: " << grad_param->cnode->DebugString();
MS_LOG(DEBUG) << "This op has not custom bprop: " << grad_param->cnode->DebugString();
variable_adjoint->set_is_fake_bprop(true);
variable_adjoint->set_fake_prim_name(prim->name());
}
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
// record last_node for brackpropagate
last_node_ = grad_param->cnode;
return true;

@@ -368,9 +382,9 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList args_node_list;
CNodePtr bprop_cnode = nullptr;
AnfNodePtr k_node = nullptr;
AnfNodePtr dout = nullptr;
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {

@@ -391,7 +405,11 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, grad_param->out);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
}
std::vector<CNodePtr> outputs;
for (size_t i = 1; i < grad_param->cnode->size(); ++i) {
// bprop_app[0] env

@@ -399,11 +417,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
din->set_abstract(grad_param->op_args[i - 1]->ToAbstract()->Broaden());
(void)outputs.emplace_back(din);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, grad_param->out);
variable_adjoint->set_k_node(k_node);
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
need_do_manager_replace_ = true;
return true;
}

@@ -426,7 +441,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
auto get_bprop =
bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))});
// Get graph after optimize
// Get bprop from fprop_fg, it is 2th output of fprop_fg
AnfNodePtrList node_list{get_bprop};
auto dout = bprop_builder->add_parameter();
MS_EXCEPTION_IF_NULL(out);

@@ -434,12 +449,15 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
(void)node_list.emplace_back(dout);
auto call_bprop = bprop_builder->NewCNode(node_list);
bprop_builder->set_output(call_bprop);
// Call pass for optimize graph, such as inline
auto after_opt_fg = OptimizeBpropBuilder(bprop_builder);
// Call by tape_
MS_EXCEPTION_IF_NULL(tape_dout);
*tape_dout = BuildZerosLikeNode(tape_, out);
(void)bprop_builder_inputs.emplace_back(*tape_dout);
bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
(void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
get_bprop = tape_->NewCNode(bprop_builder_inputs);
// tape_dout is set by next op
AddUser(*tape_dout, get_bprop, bprop_builder_inputs.size() - 1);

@@ -451,7 +469,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
MS_EXCEPTION_IF_NULL(sens_out);
MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
last_node_ = output_node;
sens_value_ = FilterSensValues(sens_out);
sens_value_ = sens_out;
}
FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,

@@ -463,45 +481,46 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
if (!last_node_->isa<ValueNode>() && !last_node_->isa<Parameter>()) {
(void)BackPropagate();
}
// Return the gradient;
if (grad_attr.get_by_position && grad_position.empty()) {
MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
}
SetOutput(weights, grad_position, grad_attr);
// Replace Parameter of primal funcgraph with parameter of tape_;
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, grad_attr.has_sens);
pynative::PyNativeAlgo::Common::DumpGraphIR("before_final_opt.ir", tape_);
return tape_;
}
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) {
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint) {
MS_EXCEPTION_IF_NULL(grad_param);
std::vector<AnfNodePtr> node_list;
(void)node_list.emplace_back(grad_param->cnode->input(0));
auto out_abs = grad_param->out->ToAbstract()->Broaden();
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
const auto &v = grad_param->op_args[i];
auto node = grad_param->cnode->input(i + 1);
if (node->isa<Parameter>()) {
node_list.emplace_back(node);
node->set_abstract(v->ToAbstract());
(void)node_list.emplace_back(node);
node->set_abstract(v->ToAbstract()->Broaden());
continue;
}
auto v_node = NewValueNode(grad_param->op_args[i]);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract());
node_list.emplace_back(v_node);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract()->Broaden());
(void)node_list.emplace_back(v_node);
}
// Set out
auto out_node = NewValueNode(grad_param->out);
out_node->set_abstract(out_abs);
(void)node_list.emplace_back(out_node);
} else {
// Input is a Parameter or cnode, not a value node
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list);
// Set out
MS_EXCEPTION_IF_NULL(variable_adjoint);
(void)node_list.emplace_back(variable_adjoint->k_node());
}
auto out_node = NewValueNode(grad_param->out);
auto out_abs = grad_param->out->ToAbstract()->Broaden();
out_node->set_abstract(out_abs);
// set out
node_list.emplace_back(out_node);
// set dout
node_list.emplace_back(dout);
// Set dout
(void)node_list.emplace_back(dout);
auto input_node = tape_->NewCNode(node_list);
input_node->set_abstract(out_abs);
return input_node;

@@ -511,7 +530,7 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
std::vector<AnfNodePtr> *const node_list) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString();
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
if (cnode->input(i)->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {

@@ -520,13 +539,13 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
(void)node_list->emplace_back(input_adjoint_iter->second->k_node());
} else {
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract());
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()->Broaden());
(void)node_list->emplace_back(cnode->input(i));
}
}
}
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) {
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList node_list;
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {

@@ -534,7 +553,8 @@ void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const Variable
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
VariableNode->set_k_node(k_node);
variable_adjoint->set_k_node(k_node);
MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {

@@ -542,7 +562,7 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
if (input_node->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString();
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << input_node->DebugString();
}
return input_adjoint_iter->second->k_node();
} else {

@@ -561,8 +581,9 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
}
for (size_t i = 0; i < op_args.size(); ++i) {
auto node = cnode->input(i + 1);
auto din = dins[i];
const auto &node = cnode->input(i + 1);
const auto &din = dins[i];
MS_LOG(DEBUG) << "Node " << node->DebugString() << ", din " << din->DebugString();
UpdateNextEdges(fn, node, din, op_args[i]);
}
}

@@ -613,29 +634,27 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const AnfNodeP
AddParameterNode(param, tensor);
UpdateNextEdges(fn, node, din, op_arg);
} else {
MS_LOG(DEBUG) << "it is not a cnode: " << node->DebugString();
MS_LOG(DEBUG) << "It is not a cnode or parameter: " << node->DebugString();
return;
}
}
void AutoGradCellImpl::BuildForwardLastNode() {
MS_EXCEPTION_IF_NULL(last_node_);
if (last_node_->isa<ValueNode>() ||
anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
return;
}
if (anfnode_to_variable_adjoint_.find(last_node_) == anfnode_to_variable_adjoint_.end()) {
auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
// If last_node is a maketuple or tuplegetitem, need update next edges,
// if last_node is parameter, not need to update next edges.
if (last_node_->isa<CNode>()) {
UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
}
auto input_adjoint = std::make_shared<VariableNode>(fn, sens_value_);
anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
} else {
MS_LOG(EXCEPTION) << "unprocessed node" << last_node_->DebugString();
MS_LOG(DEBUG) << "Process last node info " << last_node_->DebugString();
auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
// If last_node is a maketuple or tuplegetitem, need update next edges,
// if last_node is parameter, not need to update next edges.
if (last_node_->isa<CNode>()) {
UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
}
auto input_adjoint = std::make_shared<VariableAdjoint>(fn, sens_value_);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
}
void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor) {

@@ -643,9 +662,9 @@ void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const Value
MS_EXCEPTION_IF_NULL(tensor);
auto zeros_like_dout = BuildZerosLikeNode(tape_, tensor);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableNode>(func_node, tensor);
anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
weights_.push_back(parameter);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, tensor);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
(void)weights_used_in_graph_.emplace_back(parameter);
}
AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,

@@ -653,8 +672,8 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
MS_EXCEPTION_IF_NULL(out_value);
MS_EXCEPTION_IF_NULL(sub_value);
MS_EXCEPTION_IF_NULL(din);
std::string out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
std::string sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
const auto &out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
const auto &sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
if (out_value_id == sub_value_id) {
return din;
} else if (out_value->isa<tensor::Tensor>()) {

@@ -668,13 +687,13 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
}
auto value_seq = out_value->cast<ValueSequencePtr>();
int index = -1;
for (auto value : value_seq->value()) {
for (const auto &value : value_seq->value()) {
auto real_din = GetRealDin(fn, value, sub_value, din);
(void)inputs.emplace_back(real_din);
// if exist din == fake_dout, we record it in user vector
if (din == fn->fake_dout() && real_din == din) {
index = inputs.size() - 1;
index = static_cast<int>(inputs.size()) - 1;
}
}
auto new_din = tape_->NewCNode(inputs);

@@ -700,7 +719,7 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
prim_py->AddBpropCutPrim(bprop_cut);
if (prim->HasAttr("cell_id")) {
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
if (cell_id != "") {
if (!cell_id.empty()) {
(void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
(void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
}

@@ -719,12 +738,11 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
if (i < args_size) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(SizeToLong(i - 1))});
din->set_abstract(cnode->input(i)->abstract()->Broaden());
outputs->emplace_back(din);
(void)outputs->emplace_back(din);
(void)abs.emplace_back(din->abstract());
}
}
output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
return;
}
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {

@@ -755,11 +773,7 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
}
void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
MS_EXCEPTION_IF_NULL(last_node_);
MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
BuildForwardLastNode();
// Add sens parameter
ParameterPtr sens_param = nullptr;
if (has_sens_arg) {

@@ -769,8 +783,9 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
}
// update dout for dout
MS_EXCEPTION_IF_NULL(last_node_);
if (anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
auto variable = anfnode_to_variable_adjoint_.at(last_node_);
const auto &variable = anfnode_to_variable_adjoint_.at(last_node_);
if (has_sens_arg && sens_param != nullptr) {
variable->fn()->UpdateAccumulativeDout(sens_param);
} else {

@@ -782,19 +797,19 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
need_grad_weights_.clear();
for (const auto &weight : weights) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
auto p = tape_->add_parameter();
(void)need_grad_weights_.emplace(weight);
auto input_w = weight->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(input_w);
// Use name to match weight parameter in high order
auto default_param = input_w->default_param();
p->set_name(input_w->name());
p->set_default_param(default_param);
p->set_abstract(default_param->ToAbstract()->Broaden());
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
(void)need_grad_weights_.emplace(t->id());
auto p = tape_->add_parameter();
auto param = weight->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
p->set_name(param->name());
p->set_default_param(t);
p->set_abstract(t->ToAbstract()->Broaden());
}
}
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
for (auto iter = anfnode_to_variable_adjoint_.rbegin(); iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
if (!iter->first->isa<CNode>()) {
continue;

@@ -810,33 +825,40 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
void AutoGradCellImpl::BackPropagate() {
const auto &last_node_reverse_iter = GetLastNodeReverseIter();
bool has_primc = false;
for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
auto variable = iter->second;
const auto &variable = iter->second;
if (!variable->is_need_propagate()) {
MS_LOG(DEBUG) << "No need grad";
continue;
}
if (variable->is_need_propagate() && variable->is_fake_bprop()) {
MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has not corresponding bprop!";
if (variable->is_fake_bprop()) {
MS_LOG(WARNING) << variable->fake_prim_name() << " op has not corresponding bprop!";
continue;
}
auto fn = variable->fn();
if (!has_primc && iter->first->isa<CNode>() && GetCNodePrimitive(iter->first) != nullptr) {
has_primc = true;
}
const auto &fn = variable->fn();
// replace real dout to fake dout
Replace(fn->fake_dout(), fn->RealDout());
// replace edges which exist fake dout
fn->ReplaceEdges();
auto &next_edges = fn->next_edges();
const auto &next_edges = fn->next_edges();
for (const auto &next_edge : next_edges) {
auto node = next_edge.first;
auto din = next_edge.second;
const auto &node = next_edge.first;
const auto &din = next_edge.second;
if (anfnode_to_variable_adjoint_.find(node) == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "current node not find corresponding node";
MS_LOG(EXCEPTION) << "Current node not find corresponding node";
}
auto last_variable = anfnode_to_variable_adjoint_[node];
last_variable->fn()->UpdateAccumulativeDout(din);
last_variable->set_is_need_propagate(true);
}
}
tape_->set_flag(kPrimCPrimPyMixed, has_primc && need_do_manager_replace_);
}
AnfNodePtr AutoGradCellImpl::GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index) {

@@ -850,7 +872,7 @@ AnfNodePtr AutoGradCellImpl::GetGradNodeByIndex(const AnfNodePtrList &node_list,
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
// If weight is not used in the forward network, just return zeros_like() as dout.
if (grad_node->isa<Parameter>()) {
MS_LOG(WARNING) << "Weight does not participate in forward calculation, weight: " << grad_node->DebugString();
MS_LOG(INFO) << "Weight does not participate in forward calculation, weight: " << grad_node->DebugString();
auto w = grad_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(w);
auto default_param = w->default_param();

@@ -920,34 +942,6 @@ AnfNodePtr AutoGradCellImpl::GetWeightGrad(bool grad_weights, const AnfNodePtrLi
}
}
bool AutoGradCellImpl::IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const {
if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
return false;
}
auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
!IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
return false;
}
constexpr int kEmptyTupeSize = 1;
if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
return false;
}
return true;
}
AnfNodePtr AutoGradCellImpl::GenerateEmptyTupleValue() {
std::vector<ValuePtr> value_list;
auto inputs_value = std::make_shared<ValueTuple>(value_list);
auto weights_value = std::make_shared<ValueTuple>(value_list);
std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
auto tuple_value_node = NewValueNode(tuple_value);
tuple_value_node->set_abstract(tuple_value->ToAbstract());
return tuple_value_node;
}
void AutoGradCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
const GradAttr &grad_attr) {
auto inputs_grad_ret = GetInputGrad(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);

@@ -1012,8 +1006,8 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
}
void AutoGradCellImpl::ElimateTupleGetItem() {
for (auto iter = users_.begin(); iter != users_.end(); iter++) {
auto old_node = iter->first;
for (auto &user : users_) {
auto old_node = user.first;
if (!old_node->isa<CNode>()) {
continue;
}

@@ -1034,6 +1028,7 @@ void AutoGradCellImpl::ElimateTupleGetItem() {
void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
const auto &parameters = tape_->parameters();
auto cell_inputs_size = cell_inputs_.size();
pynative::PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", tape_);
if (need_do_manager_replace_) {
MS_LOG(DEBUG) << "Do parameter replace by manager";
auto mng = MakeManager({tape_}, false);

@@ -1065,13 +1060,12 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
}
}
for (auto &weight : weights_) {
if (need_grad_weights_.find(weight) == need_grad_weights_.end()) {
auto parameter = weight->cast<ParameterPtr>();
const auto &input_value = parameter->default_param();
MS_EXCEPTION_IF_NULL(input_value);
auto value_node = NewValueNode(input_value);
value_node->set_abstract(input_value->ToAbstract()->Broaden());
for (auto &weight : weights_used_in_graph_) {
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
if (need_grad_weights_.find(t->id()) == need_grad_weights_.end()) {
MS_LOG(DEBUG) << "Convert " << weight->DebugString() << " to value node";
auto value_node = NewValueNode(t);
value_node->set_abstract(t->ToAbstract()->Broaden());
Replace(weight, value_node);
}
}

@@ -86,9 +86,9 @@ class FunctionNode {
};
using FunctionNodePtr = std::shared_ptr<FunctionNode>;
class VariableNode {
class VariableAdjoint {
public:
VariableNode(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
VariableAdjoint(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
ValuePtr out_value() const { return out_value_; }
FunctionNodePtr fn() const { return fn_; }

@@ -114,7 +114,7 @@ class VariableNode {
// K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_{nullptr};
};
using VariableNodePtr = std::shared_ptr<VariableNode>;
using VariableAdjointPtr = std::shared_ptr<VariableAdjoint>;
class AutoGradCellImpl {
public:

@@ -143,10 +143,10 @@ class AutoGradCellImpl {
// Top cell inputs
AnfNodePtrList cell_inputs_;
// These weights need to calculate gradient.
mindspore::HashSet<AnfNodePtr> need_grad_weights_;
mindspore::HashSet<std::string> need_grad_weights_;
// Bprop dins of each variable or middle out
OrderedMap<AnfNodePtr, VariableNodePtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_;
OrderedMap<AnfNodePtr, VariableAdjointPtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_used_in_graph_;
// Record cnode's input map for tape_
UserType users_;
// Flag for ms_funtcion and high order

@@ -156,7 +156,8 @@ class AutoGradCellImpl {
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
// construct input as cnode for expander
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout);
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);

@@ -176,7 +177,7 @@ class AutoGradCellImpl {
// Set sens and weights parameter nodes by user input info
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
// get last reverse iterator
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator GetLastNodeReverseIter();
OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
void BackPropagate();
// Set return node according to grad flag

@@ -184,14 +185,12 @@ class AutoGradCellImpl {
AnfNodePtr GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index);
AnfNodePtr GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
AnfNodePtr GetWeightGrad(bool grad_weights, const AnfNodePtrList &weights, bool weight_param_is_tuple);
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const;
AnfNodePtr GenerateEmptyTupleValue();
void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void ElimateTupleGetItem();
// Fbprop
void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
void BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint);
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list);
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);
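
The FunctionNode/VariableAdjoint pair above is the bookkeeping behind PyNative reverse-mode autodiff: each recorded node holds an accumulative dout plus next edges back to its inputs, and BackPropagate() walks the adjoint map in reverse insertion order, pushing each din into the corresponding input. The following is a minimal, self-contained scalar analogue of that tape in plain C++; it is illustrative only (the Adjoint struct, the integer node ids, and the scalar example are invented and are not MindSpore types).

// Minimal scalar analogue of the VariableAdjoint/FunctionNode tape (illustrative only).
#include <cstdio>
#include <functional>
#include <map>
#include <utility>
#include <vector>

struct Adjoint {
  double out_value;     // forward value recorded for this node
  double dout;          // accumulative dout, like FunctionNode's accumulated gradient
  bool need_propagate;  // set once some consumer has pushed a din into this node
  // Next edges: (input node id, local vector-Jacobian product).
  std::vector<std::pair<int, std::function<double(double)>>> next_edges;
};

int main() {
  const double x = 3.0;
  std::map<int, Adjoint> tape;  // keyed by creation order, like the ordered adjoint map
  tape[0] = {x, 0.0, false, {}};                                                // node 0: x
  tape[1] = {x * x, 0.0, false, {{0, [x](double d) { return d * 2.0 * x; }}}};  // node 1: y = x * x
  tape[2] = {tape[1].out_value + x, 0.0, false,
             {{1, [](double d) { return d; }}, {0, [](double d) { return d; }}}};  // node 2: z = y + x

  // Seed the last node with sens = 1 and back-propagate in reverse order,
  // accumulating each din into the corresponding input's dout.
  tape[2].dout = 1.0;
  tape[2].need_propagate = true;
  for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
    if (!it->second.need_propagate) {
      continue;
    }
    for (const auto &edge : it->second.next_edges) {
      tape[edge.first].dout += edge.second(it->second.dout);
      tape[edge.first].need_propagate = true;
    }
  }
  std::printf("dz/dx = %g\n", tape[0].dout);  // prints 7, i.e. 2 * x + 1
  return 0;
}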

@@ -34,10 +34,12 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
namespace {
const std::map<std::string, std::vector<std::string>> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}},
{prim::kPrimMatMul->name(), {kTransposeA, kTransposeB}}};
}
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {

@@ -53,7 +55,7 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
continue;
}
parallel::OperatorAttrs attrs;
auto iter = op2attrs.find(primitive->name());
const auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {
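
Conceptually, ConvertPrimToPrimPy uses op2attrs as a whitelist: it looks up the primitive's name and copies only the listed attributes onto the new python primitive, and the hunk above simply extends that whitelist with MatMul's transpose_a/transpose_b. A small self-contained sketch of the lookup-and-copy step with generic containers and invented attribute values (not the MindSpore types) follows.

// Whitelist-driven attribute copy, mirroring the op2attrs lookup (illustrative only).
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Attributes to carry over, keyed by operator name (same shape as op2attrs above).
  const std::map<std::string, std::vector<std::string>> op2attrs = {
      {"BroadcastTo", {"shape"}},
      {"ReduceSum", {"keep_dims"}},
      {"MatMul", {"transpose_a", "transpose_b"}}};
  // Attributes present on one primitive instance (invented example values).
  const std::map<std::string, std::string> prim_attrs = {
      {"transpose_a", "false"}, {"transpose_b", "true"}, {"output_names", "y"}};

  std::vector<std::pair<std::string, std::string>> copied;
  const auto iter = op2attrs.find("MatMul");
  if (iter != op2attrs.end()) {
    for (const auto &attr : iter->second) {
      const auto it = prim_attrs.find(attr);
      if (it != prim_attrs.end()) {             // same role as primitive->HasAttr(attr)
        copied.emplace_back(attr, it->second);  // only whitelisted attrs are forwarded
      }
    }
  }
  for (const auto &kv : copied) {
    std::cout << kv.first << " = " << kv.second << "\n";  // transpose_a and transpose_b only
  }
  return 0;
}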

@@ -29,40 +29,32 @@ using mindspore::tensor::Tensor;
namespace mindspore {
namespace parallel {
std::string GetOpPythonPath(const OperatorName &op_name) {
// almost all ops are defined in two main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
const std::string grad_ops_module = GRAD_OP_PATH;
const std::string functional_op_module = FUNCTIONAL_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module));
py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
const char *GetOpPythonPath(const char *op_name) {
static py::module inner_mod = py::module::import(INNER_OP_PATH);
if (py::hasattr(inner_mod, op_name)) {
return INNER_OP_PATH;
}
if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
return inner_ops_module;
static py::module mod = py::module::import(OP_PATH);
if (py::hasattr(mod, op_name)) {
return OP_PATH;
}
if (py::hasattr(mod, common::SafeCStr(op_name))) {
return ops_module;
static py::module grad_mod = py::module::import(GRAD_OP_PATH);
if (py::hasattr(grad_mod, op_name)) {
return GRAD_OP_PATH;
}
if (py::hasattr(grad_mod, common::SafeCStr(op_name))) {
return grad_ops_module;
static py::module functional_mod = py::module::import(FUNCTIONAL_OP_PATH);
if (!py::hasattr(functional_mod, op_name)) {
MS_LOG(EXCEPTION) << OP_PATH << " and " << INNER_OP_PATH << " and " << GRAD_OP_PATH << " and " << FUNCTIONAL_OP_PATH
<< " don't have op:" << op_name;
}
if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and "
<< functional_op_module << " don't have op:" << op_name;
}
return functional_op_module;
return FUNCTIONAL_OP_PATH;
}
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
std::string op_path = GetOpPythonPath(op_name);
py::module mod = py::module::import(common::SafeCStr(op_path));
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
return nullptr;
}
const auto op_path = GetOpPythonPath(op_name.c_str());
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const Attr &attr) { return ValueToPyData(attr.second); });

@@ -31,7 +31,7 @@ namespace mindspore {
namespace parallel {
const char USING_HASH_NAME[] = "USING_HASH_NAME";
// Get the operator's path where the operator has be defined
std::string GetOpPythonPath(const OperatorName &op_name);
const char *GetOpPythonPath(const char *op_name);
// Init python operator Instance
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
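
The rewritten GetOpPythonPath above returns a constant path string and caches each imported module in a function-local static, so repeated lookups no longer re-import the four modules or build std::string copies. A rough, self-contained sketch of that caching pattern with pybind11's embedding API is shown below; the module names ("math", "cmath") and the LookUpOp helper are placeholders for illustration, not the MindSpore paths, and the cached modules are deliberately leaked so nothing touches Python refcounts after the interpreter shuts down.

// Sketch of caching imported modules in function-local statics (not MindSpore code).
#include <pybind11/embed.h>
#include <cstdio>
namespace py = pybind11;

const char *LookUpOp(const char *name) {
  // Imported once on first use and reused afterwards. Held by pointer and never
  // deleted, so no Python refcount is decremented after interpreter finalization.
  static py::module_ *math_mod = new py::module_(py::module_::import("math"));
  if (py::hasattr(*math_mod, name)) {
    return "math";
  }
  static py::module_ *cmath_mod = new py::module_(py::module_::import("cmath"));
  if (py::hasattr(*cmath_mod, name)) {
    return "cmath";
  }
  return nullptr;  // the real code raises an exception listing all searched paths
}

int main() {
  py::scoped_interpreter guard{};  // start an embedded interpreter for the demo
  const char *path = LookUpOp("sqrt");
  std::printf("sqrt found in: %s\n", path != nullptr ? path : "<none>");
  return 0;
}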

@@ -498,6 +498,8 @@ constexpr auto kAttrReshapeType = "reshape_type";
constexpr auto kAttrAxis = "axis";
constexpr auto kAttrAxes = "axes";
constexpr auto kAttrKeepDims = "keep_dims";
constexpr auto kTransposeA = "transpose_a";
constexpr auto kTransposeB = "transpose_b";
constexpr auto kAttrSkipMode = "skip_mode";
constexpr auto kAttrShapeGamma = "shape_gamma";
constexpr auto kAttrPerm = "perm";

@@ -718,6 +720,7 @@ constexpr auto kFlagNeedRenormalize = "need_renormalize";
constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
constexpr auto kFlagEnableRunGraphBySingleOp = "enable_run_graph_by_single_op";
constexpr auto kPrimCPrimPyMixed = "primc_primpy_mixed";
// TODO(dsj): for ms_function running in graph_mode. should be delete later
constexpr auto kAttrMSFunction = "ms_function_graph";

@@ -198,6 +198,8 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.depend_value_elim_,
irpass.switch_simplify_,
irpass.ad_related_special_op_eliminate_,
});
OptPassGroupMap map({{"ad_final_opt", bg_final_opt}});
if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {

@@ -250,6 +252,7 @@ FuncGraphPtr OptGradGraphPass(const ResourcePtr &resource) {
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
func_graph = bprop_graph_final_opt->step(func_graph, true);
};
func_graph = LiftingClone(func_graph);
// Validate(func_graph);
return func_graph;
}

@@ -88,12 +88,17 @@ struct InputArgsInfo {
bool has_custom_bprop;
size_t input_size;
std::string obj_id;
bool has_sens{false};
bool use_dynamic_shape_process = false;
bool grad_is_running{false};
bool use_dynamic_shape_process{false};
PrimitivePyPtr custom_bprp_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;
std::string already_run_cell_id;
std::string input_args_id;
// Cell unique id, cell_id + cell_order;
std::string obj_order_id;
size_t custom_bprop_cell_count = 0;
size_t grad_order = 0;
std::vector<std::string> input_arg_id_vec;

@@ -161,9 +161,9 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
MS_LOG(DEBUG) << "RunOp name: " << op_run_info->base_op_run_info.op_name;
// 1.Set cast for inputs
SetCastForInputs(op_run_info);
// 2. Infer output abstract
// 2.Infer output abstract
InferOutputAbstract(op_run_info);
// 3. Run op with selected backend
// 3.Run op with selected backend
if (!op_run_info->output_get_by_infer_value) {
GetOutput(op_run_info);
}

@@ -171,7 +171,6 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
MS_LOG(DEBUG) << "Grad flag is false";
return;
}
// Set forward output flag for release memory,
// Because tensor address may change, it should set in main thread to ensure consistency.
PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);

@@ -181,7 +180,10 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
return;
}
// 4. Do op grad and record op info
if (!is_ms_function_compiling_) {
// If cell have custom bprop, no need do op grad. Otherwise, need do.
// If ms function is complime, real run op must be not record
bool is_custom_bprop = op_run_info->custom_bprop_cell_count <= 0;
if (!is_ms_function_compiling_ && is_custom_bprop) {
grad()->ProcessOpGradInfo(op_run_info);
}
}

File diff suppressed because it is too large.

@@ -82,6 +82,10 @@ class GradExecutor {
inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
use_dynamic_shape_process_ = use_dynamic_shape_process;
}
inline InputArgsInfoPtr top_input_args_info() const {
MS_EXCEPTION_IF_NULL(top_input_args_info_);
return top_input_args_info_;
}
inline bool need_renormalize() const { return need_renormalize_; }
inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }

@@ -101,28 +105,24 @@ class GradExecutor {
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
const py::args &args);
TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id);
void GetPreRunTopCell(const py::object &obj);
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
const std::vector<tensor::TensorPtr> &pre_tensors) const;
void ClearRes();
void WorkerJoin() { async_executor_->WorkerJoin(); }
void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
void CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node = false,
const std::string &graph_phase = "") const;
private:
ForwardExecutorPtr forward() const;
inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }
std::string GetCurCellOrder() const;
void SetGradOrder(const std::string &cell_id);
void SetGradOrder(const std::string &obj_id);
void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
const CNodePtr &cnode) const;
void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;

@@ -158,7 +158,6 @@ class GradExecutor {
void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const;
void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info);
void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info);
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const;
// Manage resource when run grad process.
bool IsBpropGraph(const std::string &cell_id) const;

@@ -166,6 +165,11 @@ class GradExecutor {
InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args);
void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
void UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj, const py::object &out,
const py::args &args);
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info);
void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info);

@@ -189,14 +193,10 @@ class GradExecutor {
AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
bool IsGraphDynamic(const CNodePtr &cnode, size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
const std::vector<AnfNodePtr> &new_anf_inputs) const;
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
bool is_ms_function_node, const std::string &graph_phase) const;
bool grad_flag_{false};
bool grad_is_running_{false};
bool need_renormalize_{false};

@@ -204,16 +204,11 @@ class GradExecutor {
mutable bool use_dynamic_shape_process_{false};
mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
int custom_bprop_cell_count_{0};
// Used in sub thread
size_t cell_order_{0};
std::string cur_cell_id_{""};
size_t obj_order_{0};
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
size_t grad_order_{0};
std::string grad_operation_;
TopCellInfoPtr top_cell_{nullptr};
TopCellInfoPtr pre_top_cell_{nullptr};
InputArgsInfoPtr top_input_args_info_{nullptr};
// Records every cell info for share, regardless of whether need construct grad graph
std::stack<InputArgsInfoPtr> input_args_info_stack_;
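
Several declarations above drop the reference and take node_idx as a plain size_t. For small, trivially copyable types, pass-by-value avoids an indirection and cannot dangle; a minimal illustration of the two signatures (generic code, not from the patch):

// Passing a small scalar by value vs. by const reference (illustrative only).
#include <cstddef>
#include <cstdio>

// Before: the reference adds an indirection for a value that fits in a register.
size_t TwiceByRef(const size_t &idx) { return idx * 2; }

// After: small trivially copyable arguments are simply copied.
size_t TwiceByValue(size_t idx) { return idx * 2; }

int main() {
  std::printf("%zu %zu\n", TwiceByRef(21), TwiceByValue(21));
  return 0;
}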

@@ -75,8 +75,8 @@ size_t GetOutputTensorNumForTuple(const CNodePtr &make_tuple) {
} // namespace
void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
const std::vector<tensor::TensorPtr> &total_output_tensors,
const FuncGraphPtr &grad_graph) const {
const std::vector<tensor::TensorPtr> &total_output_tensors, const FuncGraphPtr &grad_graph,
bool is_dynamic_shape) const {
MS_EXCEPTION_IF_NULL(added_make_tuple);
size_t index = 0;
for (size_t i = 1; i < added_make_tuple->size(); ++i) {

@@ -110,6 +110,17 @@ void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
} else {
MS_LOG(EXCEPTION) << "The output value of forward cnode is empty, forward cnode info: " << cnode->ToString();
}
if (is_dynamic_shape) {
if (output_num == 1) {
output_vnode->set_abstract(new_values[0]->ToAbstract()->Broaden());
} else {
AbstractBasePtrList abs_list;
for (size_t j = 0; j < output_num; ++j) {
(void)abs_list.emplace_back(new_values[j]->ToAbstract()->Broaden());
}
output_vnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
}
}
MS_LOG(DEBUG) << "New output value node: " << output_vnode->ToString();
}
// Save op info with new tensors for current running ms_function func graph.

@@ -119,9 +130,8 @@ void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
}
}
void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const {
MS_EXCEPTION_IF_NULL(top_cell);
void MsFunction::ReplaceNewTensorsInGradGraph(const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const {
MS_EXCEPTION_IF_NULL(ms_func_graph);
// Get added forward nodes.
auto merge_node = ms_func_graph->output();

@@ -150,7 +160,7 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co
// The forward node in ms_function graph is created during compilation and is a
// placeholder(mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc).After running ms_function, need to update
// to real value.
RunReplace(added_make_tuple, total_output_tensors, grad_graph);
RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
}
void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,

@@ -190,11 +200,10 @@ void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNo
for (size_t i = 0; i < op_run_info->input_size; ++i) {
const auto &input_i_value = op_run_info->input_value[i];
const auto &id = PyNativeAlgo::Common::GetIdByValue(input_i_value);
MS_LOG(DEBUG) << "The input " << i << " id " << id
<< " value of ms_function graph is: " << input_i_value->ToString();
const auto &input_i_node = grad_executor->GetInput(input_i_value, id);
MS_EXCEPTION_IF_NULL(input_i_node);
MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
MS_LOG(DEBUG) << "The input " << i << " id " << id << " value is: " << input_i_value->ToString()
<< ", node is: " << input_i_node->DebugString();
(void)input_nodes->emplace_back(input_i_node);
}
}

@@ -230,11 +239,11 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G
} else {
top_cell->fg()->add_parameter(param);
param->debug_info()->set_name(param->name());
top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
}
(void)new_params.emplace_back(param);
(void)input_nodes->emplace_back(param);
(void)op_run_info->input_value.emplace_back(tensor_value);
top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
<< tensor_value->ToString() << ". Its name is: " << param->name();
}

@@ -254,7 +263,6 @@ void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info,
// Make a CNode which includes ms_function fprop graph and inputs node
MS_EXCEPTION_IF_NULL(ms_function_cnode);
*ms_function_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);
MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
}

@@ -291,7 +299,7 @@ void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor,
const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
const ad::GradParamPtr &grad_param) const {
MS_EXCEPTION_IF_NULL(grad_executor);
const auto fn = [this, grad_param, auto_grad_cell_ptr]() {
const auto fn = [grad_param, auto_grad_cell_ptr]() {
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode";

@@ -317,6 +325,7 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
MS_EXCEPTION_IF_NULL(op_run_info);
MS_EXCEPTION_IF_NULL(grad_executor);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
// Step 1: Update actual output tensors used in grad graph.
MS_EXCEPTION_IF_NULL(op_run_info->out_value);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();

@@ -328,8 +337,12 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().end()) {
UpdateMsFunctionForwardTensors(grad_executor, op_run_info->op_info, added_out_v);
}
ReplaceNewTensorsInGradGraph(grad_executor->top_cell(), added_out_v, ms_func_graph, grad_graph);
bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(ms_func_graph->output());
if (is_dynamic_shape) {
const_cast<GradExecutor *>(grad_executor)->set_use_dynamic_shape_process(true);
MS_LOG(DEBUG) << "Ms function is dynamic shape";
}
ReplaceNewTensorsInGradGraph(added_out_v, ms_func_graph, grad_graph, is_dynamic_shape);
// Clone new ms_function func graph and grad graph.
auto new_ms_func_graph = BasicClone(ms_func_graph);
@ -49,9 +49,9 @@ class MsFunction {
const ad::GradParamPtr &grad_param) const;
// Update device address of value node in grad graph by forward tensors.
void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
const FuncGraphPtr &grad_graph) const;
void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const;
void ReplaceNewTensorsInGradGraph(const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const;
void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
const ValuePtr &new_forward_value) const;
// Make CNode for ms_function forward graph.
@ -91,6 +91,12 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
++op_index_;
}

void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
need_compile_graph_ = need_compile_graph;
forward_already_run_ = forward_already_run;
vm_compile_ = vm_compile;
}

void TopCellInfo::ClearDeviceMemory() const {
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
auto ms_context = MsContext::GetInstance();

@ -127,6 +133,23 @@ void TopCellInfo::ClearDeviceMemory() const {
}
}

void TopCellInfo::Clear() {
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
hook_changed_ = false;
is_init_kpynative_ = false;
need_compile_graph_ = false;
forward_already_run_ = false;
vm_compile_ = false;
op_index_ = 0;
resource_ = nullptr;
fg_ = nullptr;
graph_info_map_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
cnode_hash_with_op_index_.clear();
}

void TopCellInfo::DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) {
auto &graph_info = graph_info_map().at(g);
MS_EXCEPTION_IF_NULL(graph_info);

@ -188,22 +211,6 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const
}
}

void TopCellInfo::Clear() {
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
hook_changed_ = false;
is_init_kpynative_ = false;
need_compile_graph_ = false;
forward_already_run_ = false;
op_index_ = 0;
resource_ = nullptr;
fg_ = nullptr;
graph_info_map_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
cnode_hash_with_op_index_.clear();
}

void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) const {
auto &graph_info = graph_info_map().at(fg());
@ -58,11 +58,11 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class TopCellInfo {
public:
~TopCellInfo() = default;
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string c_cell_id, std::string cellid,
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string obj_id_with_grad_order, std::string cellid,
std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg)
: is_high_order_top_cell_(is_high_order_top_cell),
grad_order_(grad_order),
c_cell_id_(std::move(c_cell_id)),
obj_id_with_grad_order_(std::move(obj_id_with_grad_order)),
cell_id_(std::move(cellid)),
already_run_cell_id_(std::move(already_run_cell_id)),
resource_(std::move(r)),

@ -79,11 +79,12 @@ class TopCellInfo {
inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
inline bool forward_already_run() const { return forward_already_run_; }
inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline bool vm_compile() const { return vm_compile_; }
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; }
inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline pipeline::ResourcePtr resource() const { return resource_; }
inline FuncGraphPtr fg() const {
MS_EXCEPTION_IF_NULL(fg_);

@ -91,7 +92,7 @@ class TopCellInfo {
}
inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
inline const std::string &cell_id() const { return cell_id_; }
inline const std::string &c_cell_id() const { return c_cell_id_; }
inline const std::string &obj_id_with_grad_order() const { return obj_id_with_grad_order_; }
inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
inline const std::string &input_args_id() const { return input_args_id_; }

@ -124,10 +125,11 @@ class TopCellInfo {
inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
cnode_hash_with_op_index_[node_hash] = op_index;
}
inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) {
auto iter = cnode_hash_with_op_index_.find(node_hash);
inline size_t get_op_index_by_cnode_hash(const size_t node_hash, const size_t node_idx) const {
const auto iter = cnode_hash_with_op_index_.find(node_hash);
if (iter == cnode_hash_with_op_index_.end()) {
MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
MS_LOG(DEBUG) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
return node_idx;
}
return iter->second;
}

@ -137,7 +139,8 @@ class TopCellInfo {
void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
bool save_flag = true) const;
bool need_save_sub_id = true) const;
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile);
void ClearDeviceMemory() const;

private:

@ -150,12 +153,13 @@ class TopCellInfo {
bool hook_changed_{false};
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool need_compile_graph_{false};
bool vm_compile_{false};
bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false};
bool need_compile_graph_{false};
size_t op_index_{0};
size_t grad_order_{0};
std::string c_cell_id_;
std::string obj_id_with_grad_order_;
std::string cell_id_;
std::string already_run_cell_id_;
std::string input_args_id_;
@ -213,6 +213,7 @@ void PyNativeExecutor::SetMsFunctionCompileStatus(bool is_compiling) const {

void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &args) const {
grad_executor()->set_use_dynamic_shape_process(true);
MS_LOG(DEBUG) << "Set dynamic shape by set inputs";
}

void RegPyNativeExecutor(const py::module *m) {
@ -29,6 +29,18 @@
namespace mindspore {
namespace pynative {
namespace PyNativeAlgo {
namespace {
void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
// Clone a new prim
MS_EXCEPTION_IF_NULL(op_run_info);
op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
}
}
} // namespace

std::string Common::GetIdByValue(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(v);
if (v->isa<tensor::Tensor>()) {

@ -101,6 +113,26 @@ bool Common::IsTensor(const ValuePtr &v, bool include_sequence) {
return v->isa<tensor::Tensor>() || v->isa<tensor::MetaSparseTensor>();
}

ValuePtr Common::FilterSensValues(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
return value;
} else if (value->isa<ValueSequence>()) {
std::vector<ValuePtr> value_list;
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
for (auto &filter_value : value_seq->value()) {
if (FilterSensValues(filter_value) != nullptr) {
(void)value_list.emplace_back(filter_value);
}
}
return std::make_shared<ValueTuple>(value_list);
} else {
MS_LOG(DEBUG) << "Value type: " << value->ToString();
return nullptr;
}
}

tensor::TensorPtr Common::GetTensorFromParam(const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(param_node);
auto param = param_node->cast<ParameterPtr>();
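FilterSensValues keeps only gradient-carrying entries (tensors, COO/CSR tensors) in a possibly nested sens value and drops everything else. Purely as an illustration of that recursive keep-or-drop pattern, and with toy types instead of MindSpore's Value hierarchy, a sketch might look like:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <variant>
#include <vector>

// Simplified stand-ins for ValuePtr: a "tensor" (kept), a scalar (dropped), or a sequence.
struct Value;
using ValuePtr = std::shared_ptr<Value>;
struct Value {
  std::variant<std::string /*tensor name*/, int /*non-tensor scalar*/, std::vector<ValuePtr> /*sequence*/> data;
};

// Keep tensors, recurse into sequences, drop everything else -- the same shape
// of logic as Common::FilterSensValues, just on toy types.
ValuePtr FilterSens(const ValuePtr &v) {
  if (std::holds_alternative<std::string>(v->data)) {
    return v;  // tensor-like: keep
  }
  if (auto seq = std::get_if<std::vector<ValuePtr>>(&v->data)) {
    std::vector<ValuePtr> kept;
    for (const auto &item : *seq) {
      if (FilterSens(item) != nullptr) {
        kept.push_back(item);
      }
    }
    return std::make_shared<Value>(Value{kept});
  }
  return nullptr;  // scalar/other: filtered out
}

int main() {
  auto t = std::make_shared<Value>(Value{std::string("grad_tensor")});
  auto n = std::make_shared<Value>(Value{42});
  auto tuple = std::make_shared<Value>(Value{std::vector<ValuePtr>{t, n}});
  auto filtered = FilterSens(tuple);
  const auto &kept = std::get<std::vector<ValuePtr>>(filtered->data);
  std::cout << "kept " << kept.size() << " of 2 entries\n";  // prints: kept 1 of 2 entries
  return 0;
}
```

On a `(tensor, 42)` pair this keeps only the tensor entry, which mirrors how the real function thins non-tensor components out of the sens value before they reach the grad graph.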
@ -523,12 +555,8 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
bool need_convert_input_to_attr = NeedConvertConstInputToAttr(op_run_info, device_target, &input_to_attr);
MS_LOG(DEBUG) << "Need convert input to addr " << need_convert_input_to_attr;
if (need_convert_input_to_attr) {
// Clone a new prim
op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
}
// Prim may be changed attr
ClonePrim(op_run_info);
}
const auto &op_prim = op_run_info->op_prim;

@ -544,10 +572,17 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
// Mark tensors, common tensor data : 0, weight param: 1, valuenode(float_, int_): 2
ConvertValueToTensor(op_run_info, input_object, index, op_prim);
// -1 indicates input_object is not a dynInput
if (op_prim->HasAttr(kAttrDynInputSizes) && !input_object->isa<ValueSequence>()) {
auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
if (op_prim->HasAttr(kAttrDynInputSizes)) {
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
// Like addn, prim define in python, but number of inputs change, so the value of kAttrDynInputSizes
// changed too. In async, do opgrad may be not complete.
ClonePrim(op_run_info);
}
if (!input_object->isa<ValueSequence>()) {
auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
}
}
}
op_prim->EndRecordAddAttr();
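The comment above is the key point: `kAttrDynInputSizes` may be rewritten for the next launch while the asynchronous grad task still references the same primitive object, so the primitive is cloned first when PyNative synchronization is off. A minimal copy-on-write sketch of that idea, with an invented `Prim` struct rather than MindSpore's `PrimitivePy`:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// A toy "primitive": a name plus an attribute map (e.g. dyn_input_sizes).
struct Prim {
  std::string name;
  std::map<std::string, std::vector<std::int64_t>> attrs;
};

int main() {
  auto prim = std::make_shared<Prim>(Prim{"AddN", {{"dyn_input_sizes", {3}}}});

  // The async grad task captures the primitive as it is *now*.
  auto captured_for_grad = prim;

  // Before changing attributes for the next launch, clone the primitive so the
  // captured copy keeps its old attribute values (copy-on-write, like ClonePrim).
  prim = std::make_shared<Prim>(*prim);
  prim->attrs["dyn_input_sizes"] = {5};

  std::cout << "grad task sees " << captured_for_grad->attrs["dyn_input_sizes"][0] << "\n";  // 3
  std::cout << "next launch sees " << prim->attrs["dyn_input_sizes"][0] << "\n";             // 5
  return 0;
}
```

Because the grad task holds the pre-clone shared pointer, the later attribute change is invisible to it; only the newly launched op sees the updated sizes.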
@ -33,6 +33,7 @@ struct Common {
static std::string GetIdByValue(const ValuePtr &v);
static bool ValueHasDynamicShape(const ValuePtr &value);
static bool IsTensor(const ValuePtr &v, bool include_sequence = false);
static ValuePtr FilterSensValues(const ValuePtr &value);
static tensor::TensorPtr GetTensorFromParam(const AnfNodePtr &param_node);
static void SetForwardOutputFlag(const ValuePtr &v);
static void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
@ -110,6 +110,7 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(op_run_info);
MS_EXCEPTION_IF_NULL(device_context);
py::gil_scoped_acquire acquire_gil;
auto graph_info = op_run_info->base_op_run_info.graph_info;
auto iter = op_compiler_infos_.find(graph_info);
// Check if the graph cache exists.

@ -160,7 +161,6 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
auto op_compiler_info =
std::make_shared<OpCompilerInfo>(graph_info, graph->graph_id(), graph, outputs_with_index, device_context, false);

py::gil_scoped_acquire acquire_gil;
op_compiler_infos_[graph_info] = op_compiler_info;
return op_compiler_info;
}
@ -225,3 +225,9 @@ def bprop_scalar_calc(x, y, out, dout):
def bprop_scalar_not(x, out, dout):
    """Backpropagator for primitive `bool_not` and `string_not`."""
    return (C.zeros_like(x),)


@bprops.register("TensorMove")
def bprop_tensor_move(x, out, dout):
    """Backpropagator for primitive `TensorMove`."""
    return (dout,)

@ -1034,5 +1034,6 @@ class RandomShuffle(Primitive):
    def __init__(self, seed=0, seed2=0):
        """Initialize RandomShuffle"""
        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
        self.add_prim_attr("side_effect_hidden", True)
        Validator.check_non_negative_int(seed, "seed", self.name)
        Validator.check_non_negative_int(seed2, "seed2", self.name)
@ -213,9 +213,9 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
/// Description:
/// Expectation: the python path is right
TEST_F(TestStepParallel, GetPythonPath1) {
OperatorName operator_name = "AllReduce";
const char *operator_name = "AllReduce";
const std::string expect = "mindspore.ops.operations";
auto temp = parallel::GetOpPythonPath(operator_name);
std::string temp = parallel::GetOpPythonPath(operator_name);
ASSERT_EQ(temp, expect);
}

@ -223,9 +223,9 @@ TEST_F(TestStepParallel, GetPythonPath1) {
/// Description:
/// Expectation: the python path is right
TEST_F(TestStepParallel, GetPythonPath2) {
OperatorName operator_name = "Add";
const char *operator_name = "Add";
const std::string expect = "mindspore.ops.operations";
auto temp = parallel::GetOpPythonPath(operator_name);
std::string temp = parallel::GetOpPythonPath(operator_name);
ASSERT_EQ(temp, expect);
}