!46504 Fix bug for PyNative

Merge pull request !46504 from zjun/fix_high
i-robot 2022-12-08 06:12:41 +00:00 committed by Gitee
commit b34dd895f3
23 changed files with 687 additions and 590 deletions

View File

@ -375,6 +375,17 @@ void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, con
MS_EXCEPTION_IF_NULL(make_tuple);
args_tuple_num = make_tuple->inputs().size() - 1;
continue;
} else if (input_node->isa<Parameter>()) {
auto param = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
auto abs = param->abstract();
MS_EXCEPTION_IF_NULL(abs);
if (abs->isa<abstract::AbstractTuple>() && !abs->isa<abstract::AbstractSparseTensor>()) {
auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs_tuple);
args_tuple_num = abs_tuple->elements().size();
continue;
}
}
}
// Hook single-input or single-output.
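
Editor's note: the added branch mirrors the existing MakeTuple handling — when a control-flow input is a Parameter whose abstract is a plain (non-sparse) tuple, its element count becomes args_tuple_num. A minimal standalone sketch of the same cast-and-count pattern, using simplified stand-in types rather than the real MindSpore classes:

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-ins; AbstractBase/AbstractTuple here are NOT the real MindSpore types.
struct AbstractBase {
  virtual ~AbstractBase() = default;
};
struct AbstractTuple : AbstractBase {
  explicit AbstractTuple(std::vector<int> elems) : elements(std::move(elems)) {}
  std::vector<int> elements;
};

// Returns the number of flattened arguments a tuple-typed parameter contributes.
size_t GetArgsTupleNum(const std::shared_ptr<AbstractBase> &abs) {
  if (auto tuple = std::dynamic_pointer_cast<AbstractTuple>(abs)) {
    return tuple->elements.size();
  }
  return 0;  // non-tuple parameters contribute a single positional argument instead
}

int main() {
  auto abs = std::make_shared<AbstractTuple>(std::vector<int>{1, 2, 3});
  std::cout << GetArgsTupleNum(abs) << std::endl;  // prints 3
}
```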

View File

@ -33,6 +33,7 @@
#include "utils/profile.h"
#include "include/common/utils/primitive_utils.h"
#include "pipeline/jit/pass.h"
namespace mindspore {
namespace ad {
namespace {
@ -43,8 +44,9 @@ constexpr char kAttrOnesLikeCOO[] = "ones_like_coo_node";
enum class SpecialType { kZerosLikeType = 0, kOnesLikeType = 1 };
const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType::kZerosLikeType, prim::kPrimZerosLike},
{SpecialType::kOnesLikeType, prim::kPrimOnesLike}};
const std::vector<PrimitivePtr> kGradBlackList{prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient,
prim::kPrimUpdateState, prim::kPrimNPUAllocFloatStatus};
const std::vector<PrimitivePtr> kGradBlackList{
prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient, prim::kPrimUpdateState,
prim::kPrimNPUAllocFloatStatus, prim::kPrimNPUGetFloatStatus, prim::kPrimNPUClearFloatStatus};
AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value, const SpecialType &type);
void ClearDeviceAddress(const ValuePtr &value) {
std::vector<tensor::TensorPtr> tensors;
@ -55,26 +57,6 @@ void ClearDeviceAddress(const ValuePtr &value) {
}
}
ValuePtr FilterSensValues(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
return value;
} else if (value->isa<ValueSequence>()) {
std::vector<ValuePtr> value_list;
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
for (auto filter_value : value_seq->value()) {
if (FilterSensValues(filter_value) != nullptr) {
(void)value_list.emplace_back(filter_value);
}
}
return std::make_shared<ValueTuple>(value_list);
} else {
MS_LOG(DEBUG) << "value type: " << value->ToString();
return nullptr;
}
}
bool IsPrimNeedGrad(const PrimitivePtr &prim) {
for (const auto &no_need_grad_prim : kGradBlackList) {
if (IsPrimitiveEquals(prim, no_need_grad_prim)) {
@ -234,6 +216,7 @@ bool IsZerosLikeNode(const AnfNodePtr &node) {
}
FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_before_opt.ir", bprop_func_graph);
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
resource->set_func_graph(bprop_func_graph);
auto manager = resource->manager();
@ -243,6 +226,34 @@ FuncGraphPtr OptimizeBpropBuilder(const FuncGraphPtr &bprop_func_graph) {
pynative::PyNativeAlgo::Common::DumpGraphIR("bprop_builder_after_opt.ir", after_opt_bg);
return after_opt_bg;
}
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) {
if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
return false;
}
auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
!IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
return false;
}
constexpr int kEmptyTupeSize = 1;
if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
return false;
}
return true;
}
AnfNodePtr GenerateEmptyTupleValue() {
std::vector<ValuePtr> value_list;
auto inputs_value = std::make_shared<ValueTuple>(value_list);
auto weights_value = std::make_shared<ValueTuple>(value_list);
std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
auto tuple_value_node = NewValueNode(tuple_value);
tuple_value_node->set_abstract(tuple_value->ToAbstract());
return tuple_value_node;
}
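
Editor's note: these two helpers were previously methods of AutoGradCellImpl (their old definitions are removed further down in this diff). As file-local functions they are presumably combined roughly as in the hedged sketch below; the call site and the MakeTuple fallback are assumptions, not code from this patch:

```cpp
// Hedged sketch: if neither inputs nor weights produced any gradient, return a
// constant ((), ()) value node; otherwise pack the two gradient nodes as usual.
AnfNodePtr ChooseGradOutput(const FuncGraphPtr &tape, const AnfNodePtr &inputs_grad,
                            const AnfNodePtr &weights_grad) {
  if (IsOutputBothEmpty(inputs_grad, weights_grad)) {
    return GenerateEmptyTupleValue();
  }
  return tape->NewCNode({NewValueNode(prim::kPrimMakeTuple), inputs_grad, weights_grad});
}
```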
} // namespace
AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr &right_node) {
@ -302,14 +313,15 @@ void FunctionNode::ReplaceEdges() {
AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std::vector<ValuePtr> &input_param_values)
: tape_(std::make_shared<FuncGraph>()), cell_inputs_(cell_inputs) {
tape_->debug_info()->set_name("grad_top");
MS_LOG(DEBUG) << "Start AutoGradCellImpl, cell_inputs size: " << cell_inputs.size();
for (size_t i = 0; i < cell_inputs.size(); ++i) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(cell_inputs[i]->debug_info()));
auto parameter = tape_->add_parameter();
parameter->set_abstract(input_param_values[i]->ToAbstract()->Broaden());
auto zeros_like_dout = BuildZerosLikeNode(tape_, input_param_values[i]);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableNode>(func_node, input_param_values[i]);
anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, input_param_values[i]);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(cell_inputs[i], input_adjoint));
}
}
@ -329,12 +341,14 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, cloned_value);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
need_do_manager_replace_ = true;
}
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout);
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint);
MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
// Gradient outputs
std::vector<CNodePtr> outputs;
#ifndef ENABLE_TEST
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
@ -342,7 +356,7 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
} else {
mindspore::BuildBprop(input_node, &outputs, &users_);
if (outputs.empty()) {
MS_LOG(DEBUG) << "expander has no bprop of this prim: " << grad_param->cnode->DebugString();
MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << grad_param->cnode->DebugString();
BuildCustomBpropCNode(input_node, &outputs);
}
}
@ -356,11 +370,11 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
if (!outputs.empty()) {
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
} else {
MS_LOG(DEBUG) << "this op has not custom bprop: " << grad_param->cnode->DebugString();
MS_LOG(DEBUG) << "This op has not custom bprop: " << grad_param->cnode->DebugString();
variable_adjoint->set_is_fake_bprop(true);
variable_adjoint->set_fake_prim_name(prim->name());
}
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
// Record last_node for backpropagation
last_node_ = grad_param->cnode;
return true;
@ -368,9 +382,9 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList args_node_list;
CNodePtr bprop_cnode = nullptr;
AnfNodePtr k_node = nullptr;
AnfNodePtr dout = nullptr;
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
@ -391,7 +405,11 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, grad_param->out);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
}
std::vector<CNodePtr> outputs;
for (size_t i = 1; i < grad_param->cnode->size(); ++i) {
// bprop_app[0] env
@ -399,11 +417,8 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
din->set_abstract(grad_param->op_args[i - 1]->ToAbstract()->Broaden());
(void)outputs.emplace_back(din);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableNode>(fn, grad_param->out);
variable_adjoint->set_k_node(k_node);
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
need_do_manager_replace_ = true;
return true;
}
@ -426,7 +441,7 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
auto get_bprop =
bprop_builder->NewCNode({NewValueNode(prim::kPrimTupleGetItem), fprop_app, NewValueNode(static_cast<int64_t>(1))});
// Get graph after optimize
// Get bprop from fprop_fg; it is the second output of fprop_fg
AnfNodePtrList node_list{get_bprop};
auto dout = bprop_builder->add_parameter();
MS_EXCEPTION_IF_NULL(out);
@ -434,12 +449,15 @@ CNodePtr AutoGradCellImpl::GetBPropFromFProp(const FuncGraphPtr &fprop_fg, const
(void)node_list.emplace_back(dout);
auto call_bprop = bprop_builder->NewCNode(node_list);
bprop_builder->set_output(call_bprop);
// Run optimization passes on the graph, such as inline
auto after_opt_fg = OptimizeBpropBuilder(bprop_builder);
// Call by tape_
MS_EXCEPTION_IF_NULL(tape_dout);
*tape_dout = BuildZerosLikeNode(tape_, out);
(void)bprop_builder_inputs.emplace_back(*tape_dout);
bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
(void)bprop_builder_inputs.insert(bprop_builder_inputs.cbegin(), NewValueNode(after_opt_fg));
get_bprop = tape_->NewCNode(bprop_builder_inputs);
// tape_dout is set by next op
AddUser(*tape_dout, get_bprop, bprop_builder_inputs.size() - 1);
@ -451,7 +469,7 @@ void AutoGradCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node,
MS_EXCEPTION_IF_NULL(sens_out);
MS_LOG(DEBUG) << "Real output node of top cell is " << output_node->DebugString();
last_node_ = output_node;
sens_value_ = FilterSensValues(sens_out);
sens_value_ = sens_out;
}
FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
@ -463,45 +481,46 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
if (!last_node_->isa<ValueNode>() && !last_node_->isa<Parameter>()) {
(void)BackPropagate();
}
// Return the gradient;
if (grad_attr.get_by_position && grad_position.empty()) {
MS_LOG(EXCEPTION) << "grad_position should not be empty when grad by position!";
}
SetOutput(weights, grad_position, grad_attr);
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, grad_attr.has_sens);
pynative::PyNativeAlgo::Common::DumpGraphIR("before_final_opt.ir", tape_);
return tape_;
}
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout) {
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint) {
MS_EXCEPTION_IF_NULL(grad_param);
std::vector<AnfNodePtr> node_list;
(void)node_list.emplace_back(grad_param->cnode->input(0));
auto out_abs = grad_param->out->ToAbstract()->Broaden();
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
const auto &v = grad_param->op_args[i];
auto node = grad_param->cnode->input(i + 1);
if (node->isa<Parameter>()) {
node_list.emplace_back(node);
node->set_abstract(v->ToAbstract());
(void)node_list.emplace_back(node);
node->set_abstract(v->ToAbstract()->Broaden());
continue;
}
auto v_node = NewValueNode(grad_param->op_args[i]);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract());
node_list.emplace_back(v_node);
v_node->set_abstract(grad_param->op_args[i]->ToAbstract()->Broaden());
(void)node_list.emplace_back(v_node);
}
// Set out
auto out_node = NewValueNode(grad_param->out);
out_node->set_abstract(out_abs);
(void)node_list.emplace_back(out_node);
} else {
// Input is a Parameter or cnode, not a value node
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &node_list);
// Set out
MS_EXCEPTION_IF_NULL(variable_adjoint);
(void)node_list.emplace_back(variable_adjoint->k_node());
}
auto out_node = NewValueNode(grad_param->out);
auto out_abs = grad_param->out->ToAbstract()->Broaden();
out_node->set_abstract(out_abs);
// set out
node_list.emplace_back(out_node);
// set dout
node_list.emplace_back(dout);
// Set dout
(void)node_list.emplace_back(dout);
auto input_node = tape_->NewCNode(node_list);
input_node->set_abstract(out_abs);
return input_node;
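
Editor's note: in both branches the cnode handed to the bprop expander ends up with the same layout; only the provenance of the entries differs. A hedged summary — the bracketed layout is inferred from the code above, not stated in the patch:

```cpp
// node_list layout passed to tape_->NewCNode(node_list):
//   [0]      cnode->input(0)   -- the forward primitive
//   [1..n]   x_1 ... x_n       -- op inputs: broadened value nodes when grad_by_value,
//                                 otherwise the recorded k-nodes
//   [n+1]    out               -- forward output: value node or variable_adjoint->k_node()
//   [n+2]    dout              -- incoming gradient for this op
```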
@ -511,7 +530,7 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
std::vector<AnfNodePtr> *const node_list) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
MS_LOG(DEBUG) << "Find input knode of node " << cnode->input(i)->DebugString();
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
if (cnode->input(i)->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
@ -520,13 +539,13 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
MS_EXCEPTION_IF_NULL(input_adjoint_iter->second->k_node());
(void)node_list->emplace_back(input_adjoint_iter->second->k_node());
} else {
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract());
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()->Broaden());
(void)node_list->emplace_back(cnode->input(i));
}
}
}
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode) {
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList node_list;
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
@ -534,7 +553,8 @@ void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const Variable
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
VariableNode->set_k_node(k_node);
variable_adjoint->set_k_node(k_node);
MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
@ -542,7 +562,7 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
if (input_node->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "cannot find input in adjoint map, inp: " << input_node->DebugString();
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << input_node->DebugString();
}
return input_adjoint_iter->second->k_node();
} else {
@ -561,8 +581,9 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
}
for (size_t i = 0; i < op_args.size(); ++i) {
auto node = cnode->input(i + 1);
auto din = dins[i];
const auto &node = cnode->input(i + 1);
const auto &din = dins[i];
MS_LOG(DEBUG) << "Node " << node->DebugString() << ", din " << din->DebugString();
UpdateNextEdges(fn, node, din, op_args[i]);
}
}
@ -613,29 +634,27 @@ void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const AnfNodeP
AddParameterNode(param, tensor);
UpdateNextEdges(fn, node, din, op_arg);
} else {
MS_LOG(DEBUG) << "it is not a cnode: " << node->DebugString();
MS_LOG(DEBUG) << "It is not a cnode or parameter: " << node->DebugString();
return;
}
}
void AutoGradCellImpl::BuildForwardLastNode() {
MS_EXCEPTION_IF_NULL(last_node_);
if (last_node_->isa<ValueNode>() ||
anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
return;
}
if (anfnode_to_variable_adjoint_.find(last_node_) == anfnode_to_variable_adjoint_.end()) {
auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
// If last_node is a maketuple or tuplegetitem, need update next edges,
// if last_node is parameter, not need to update next edges.
if (last_node_->isa<CNode>()) {
UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
}
auto input_adjoint = std::make_shared<VariableNode>(fn, sens_value_);
anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
} else {
MS_LOG(EXCEPTION) << "unprocessed node" << last_node_->DebugString();
MS_LOG(DEBUG) << "Process last node info " << last_node_->DebugString();
auto zeros_like_node = BuildZerosLikeNode(tape_, sens_value_);
auto fn = std::make_shared<FunctionNode>(tape_, zeros_like_node);
// If last_node is a maketuple or tuplegetitem, need update next edges,
// if last_node is parameter, not need to update next edges.
if (last_node_->isa<CNode>()) {
UpdateNextEdges(fn, last_node_, zeros_like_node, sens_value_);
}
auto input_adjoint = std::make_shared<VariableAdjoint>(fn, sens_value_);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(last_node_, input_adjoint));
}
void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor) {
@ -643,9 +662,9 @@ void AutoGradCellImpl::AddParameterNode(const AnfNodePtr &parameter, const Value
MS_EXCEPTION_IF_NULL(tensor);
auto zeros_like_dout = BuildZerosLikeNode(tape_, tensor);
auto func_node = std::make_shared<FunctionNode>(tape_, zeros_like_dout);
auto input_adjoint = std::make_shared<VariableNode>(func_node, tensor);
anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
weights_.push_back(parameter);
auto input_adjoint = std::make_shared<VariableAdjoint>(func_node, tensor);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(parameter, input_adjoint));
(void)weights_used_in_graph_.emplace_back(parameter);
}
AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
@ -653,8 +672,8 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
MS_EXCEPTION_IF_NULL(out_value);
MS_EXCEPTION_IF_NULL(sub_value);
MS_EXCEPTION_IF_NULL(din);
std::string out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
std::string sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
const auto &out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
const auto &sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
if (out_value_id == sub_value_id) {
return din;
} else if (out_value->isa<tensor::Tensor>()) {
@ -668,13 +687,13 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
}
auto value_seq = out_value->cast<ValueSequencePtr>();
int index = -1;
for (auto value : value_seq->value()) {
for (const auto &value : value_seq->value()) {
auto real_din = GetRealDin(fn, value, sub_value, din);
(void)inputs.emplace_back(real_din);
// If din == fake_dout exists, record it in the users vector
if (din == fn->fake_dout() && real_din == din) {
index = inputs.size() - 1;
index = static_cast<int>(inputs.size()) - 1;
}
}
auto new_din = tape_->NewCNode(inputs);
@ -700,7 +719,7 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
prim_py->AddBpropCutPrim(bprop_cut);
if (prim->HasAttr("cell_id")) {
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
if (cell_id != "") {
if (!cell_id.empty()) {
(void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
(void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
}
@ -719,12 +738,11 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
if (i < args_size) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(SizeToLong(i - 1))});
din->set_abstract(cnode->input(i)->abstract()->Broaden());
outputs->emplace_back(din);
(void)outputs->emplace_back(din);
(void)abs.emplace_back(din->abstract());
}
}
output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
return;
}
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
@ -755,11 +773,7 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
}
void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
MS_EXCEPTION_IF_NULL(last_node_);
MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
BuildForwardLastNode();
// Add sens parameter
ParameterPtr sens_param = nullptr;
if (has_sens_arg) {
@ -769,8 +783,9 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
}
// Update dout for the last node
MS_EXCEPTION_IF_NULL(last_node_);
if (anfnode_to_variable_adjoint_.find(last_node_) != anfnode_to_variable_adjoint_.end()) {
auto variable = anfnode_to_variable_adjoint_.at(last_node_);
const auto &variable = anfnode_to_variable_adjoint_.at(last_node_);
if (has_sens_arg && sens_param != nullptr) {
variable->fn()->UpdateAccumulativeDout(sens_param);
} else {
@ -782,19 +797,19 @@ void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has
need_grad_weights_.clear();
for (const auto &weight : weights) {
TraceGuard trace_guard(std::make_shared<TraceCopy>(weight->debug_info()));
auto p = tape_->add_parameter();
(void)need_grad_weights_.emplace(weight);
auto input_w = weight->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(input_w);
// Use name to match weight parameter in high order
auto default_param = input_w->default_param();
p->set_name(input_w->name());
p->set_default_param(default_param);
p->set_abstract(default_param->ToAbstract()->Broaden());
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
(void)need_grad_weights_.emplace(t->id());
auto p = tape_->add_parameter();
auto param = weight->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param);
p->set_name(param->name());
p->set_default_param(t);
p->set_abstract(t->ToAbstract()->Broaden());
}
}
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator AutoGradCellImpl::GetLastNodeReverseIter() {
for (auto iter = anfnode_to_variable_adjoint_.rbegin(); iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
if (!iter->first->isa<CNode>()) {
continue;
@ -810,33 +825,40 @@ OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator AutoGradCellImpl::GetL
void AutoGradCellImpl::BackPropagate() {
const auto &last_node_reverse_iter = GetLastNodeReverseIter();
bool has_primc = false;
for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
auto variable = iter->second;
const auto &variable = iter->second;
if (!variable->is_need_propagate()) {
MS_LOG(DEBUG) << "No need grad";
continue;
}
if (variable->is_need_propagate() && variable->is_fake_bprop()) {
MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has not corresponding bprop!";
if (variable->is_fake_bprop()) {
MS_LOG(WARNING) << variable->fake_prim_name() << " op has no corresponding bprop!";
continue;
}
auto fn = variable->fn();
if (!has_primc && iter->first->isa<CNode>() && GetCNodePrimitive(iter->first) != nullptr) {
has_primc = true;
}
const auto &fn = variable->fn();
// replace real dout to fake dout
Replace(fn->fake_dout(), fn->RealDout());
// replace edges which exist fake dout
fn->ReplaceEdges();
auto &next_edges = fn->next_edges();
const auto &next_edges = fn->next_edges();
for (const auto &next_edge : next_edges) {
auto node = next_edge.first;
auto din = next_edge.second;
const auto &node = next_edge.first;
const auto &din = next_edge.second;
if (anfnode_to_variable_adjoint_.find(node) == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "current node not find corresponding node";
MS_LOG(EXCEPTION) << "Current node not find corresponding node";
}
auto last_variable = anfnode_to_variable_adjoint_[node];
last_variable->fn()->UpdateAccumulativeDout(din);
last_variable->set_is_need_propagate(true);
}
}
tape_->set_flag(kPrimCPrimPyMixed, has_primc && need_do_manager_replace_);
}
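
Editor's note: the new kPrimCPrimPyMixed flag (declared in the constants header further down in this diff) records that this tape mixes C++-expanded bprops with PrimitivePy ones while manager-based replacement was needed. A hedged sketch of how a downstream pass could consume it, assuming the usual FuncGraph set_flag/has_flag pair; the concrete consumer is not part of this diff:

```cpp
if (tape->has_flag(kPrimCPrimPyMixed)) {
  // Mixed PrimitiveC / PrimitivePy graph: take the conservative conversion path
  // before further optimization instead of the fast in-place edits.
}
```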
AnfNodePtr AutoGradCellImpl::GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index) {
@ -850,7 +872,7 @@ AnfNodePtr AutoGradCellImpl::GetGradNodeByIndex(const AnfNodePtrList &node_list,
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
// If weight is not used in the forward network, just return zeros_like() as dout.
if (grad_node->isa<Parameter>()) {
MS_LOG(WARNING) << "Weight does not participate in forward calculation, weight: " << grad_node->DebugString();
MS_LOG(INFO) << "Weight does not participate in forward calculation, weight: " << grad_node->DebugString();
auto w = grad_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(w);
auto default_param = w->default_param();
@ -920,34 +942,6 @@ AnfNodePtr AutoGradCellImpl::GetWeightGrad(bool grad_weights, const AnfNodePtrLi
}
}
bool AutoGradCellImpl::IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const {
if (!inputs_grad->isa<CNode>() || !weights_grad->isa<CNode>()) {
return false;
}
auto inputs_grad_cnode = inputs_grad->cast<CNodePtr>();
auto weights_grad_cnode = weights_grad->cast<CNodePtr>();
if (!IsPrimitiveCNode(inputs_grad_cnode, prim::kPrimMakeTuple) ||
!IsPrimitiveCNode(weights_grad_cnode, prim::kPrimMakeTuple)) {
return false;
}
constexpr int kEmptyTupeSize = 1;
if (inputs_grad_cnode->size() != kEmptyTupeSize || weights_grad_cnode->size() != kEmptyTupeSize) {
return false;
}
return true;
}
AnfNodePtr AutoGradCellImpl::GenerateEmptyTupleValue() {
std::vector<ValuePtr> value_list;
auto inputs_value = std::make_shared<ValueTuple>(value_list);
auto weights_value = std::make_shared<ValueTuple>(value_list);
std::vector<ValuePtr> tuple_list{inputs_value, weights_value};
auto tuple_value = std::make_shared<ValueTuple>(tuple_list);
auto tuple_value_node = NewValueNode(tuple_value);
tuple_value_node->set_abstract(tuple_value->ToAbstract());
return tuple_value_node;
}
void AutoGradCellImpl::SetOutput(const AnfNodePtrList &weights, const std::vector<size_t> &grad_position,
const GradAttr &grad_attr) {
auto inputs_grad_ret = GetInputGrad(grad_attr.grad_all_inputs, grad_attr.get_by_position, grad_position);
@ -1012,8 +1006,8 @@ void AutoGradCellImpl::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new
}
void AutoGradCellImpl::ElimateTupleGetItem() {
for (auto iter = users_.begin(); iter != users_.end(); iter++) {
auto old_node = iter->first;
for (auto &user : users_) {
auto old_node = user.first;
if (!old_node->isa<CNode>()) {
continue;
}
@ -1034,6 +1028,7 @@ void AutoGradCellImpl::ElimateTupleGetItem() {
void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg) {
const auto &parameters = tape_->parameters();
auto cell_inputs_size = cell_inputs_.size();
pynative::PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", tape_);
if (need_do_manager_replace_) {
MS_LOG(DEBUG) << "Do parameter replace by manager";
auto mng = MakeManager({tape_}, false);
@ -1065,13 +1060,12 @@ void AutoGradCellImpl::ReplacePrimalParameter(const AnfNodePtrList &weights, boo
}
}
for (auto &weight : weights_) {
if (need_grad_weights_.find(weight) == need_grad_weights_.end()) {
auto parameter = weight->cast<ParameterPtr>();
const auto &input_value = parameter->default_param();
MS_EXCEPTION_IF_NULL(input_value);
auto value_node = NewValueNode(input_value);
value_node->set_abstract(input_value->ToAbstract()->Broaden());
for (auto &weight : weights_used_in_graph_) {
auto t = pynative::PyNativeAlgo::Common::GetTensorFromParam(weight);
if (need_grad_weights_.find(t->id()) == need_grad_weights_.end()) {
MS_LOG(DEBUG) << "Convert " << weight->DebugString() << " to value node";
auto value_node = NewValueNode(t);
value_node->set_abstract(t->ToAbstract()->Broaden());
Replace(weight, value_node);
}
}

View File

@ -86,9 +86,9 @@ class FunctionNode {
};
using FunctionNodePtr = std::shared_ptr<FunctionNode>;
class VariableNode {
class VariableAdjoint {
public:
VariableNode(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
VariableAdjoint(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
ValuePtr out_value() const { return out_value_; }
FunctionNodePtr fn() const { return fn_; }
@ -114,7 +114,7 @@ class VariableNode {
// K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_{nullptr};
};
using VariableNodePtr = std::shared_ptr<VariableNode>;
using VariableAdjointPtr = std::shared_ptr<VariableAdjoint>;
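
Editor's note: the rename from VariableNode to VariableAdjoint is mechanical, but for orientation, a hedged construction sketch mirroring the call sites in the .cc diff above (tape, dout, out_value, k_node, primal_cnode and anfnode_to_variable_adjoint are placeholders):

```cpp
// One adjoint per primal node: holds the accumulated-dout FunctionNode, the
// forward output value, and (when not grading by value) the k-node.
auto fn = std::make_shared<FunctionNode>(tape, dout);
auto adjoint = std::make_shared<VariableAdjoint>(fn, out_value);
adjoint->set_k_node(k_node);  // only set when k-nodes are built, i.e. !grad_by_value
(void)anfnode_to_variable_adjoint.insert(std::make_pair(primal_cnode, adjoint));
```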
class AutoGradCellImpl {
public:
@ -143,10 +143,10 @@ class AutoGradCellImpl {
// Top cell inputs
AnfNodePtrList cell_inputs_;
// These weights need to calculate gradient.
mindspore::HashSet<AnfNodePtr> need_grad_weights_;
mindspore::HashSet<std::string> need_grad_weights_;
// Bprop dins of each variable or middle out
OrderedMap<AnfNodePtr, VariableNodePtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_;
OrderedMap<AnfNodePtr, VariableAdjointPtr> anfnode_to_variable_adjoint_;
AnfNodePtrList weights_used_in_graph_;
// Record cnode's input map for tape_
UserType users_;
// Flag for ms_function and high order
@ -156,7 +156,8 @@ class AutoGradCellImpl {
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
// construct input as cnode for expander
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout);
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);
@ -176,7 +177,7 @@ class AutoGradCellImpl {
// Set sens and weights parameter nodes by user input info
void SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg);
// get last reverse iterator
OrderedMap<AnfNodePtr, VariableNodePtr>::reverse_iterator GetLastNodeReverseIter();
OrderedMap<AnfNodePtr, VariableAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
void BackPropagate();
// Set return node according to grad flag
@ -184,14 +185,12 @@ class AutoGradCellImpl {
AnfNodePtr GetGradNodeByIndex(const AnfNodePtrList &node_list, size_t index);
AnfNodePtr GetInputGrad(bool grad_all_inputs, bool get_by_position, const std::vector<size_t> &grad_position);
AnfNodePtr GetWeightGrad(bool grad_weights, const AnfNodePtrList &weights, bool weight_param_is_tuple);
bool IsOutputBothEmpty(const AnfNodePtr &inputs_grad, const AnfNodePtr &weights_grad) const;
AnfNodePtr GenerateEmptyTupleValue();
void AddUser(const AnfNodePtr &node, const CNodePtr &user, size_t index);
void Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void ElimateTupleGetItem();
// Fbprop
void BuildKNode(const GradParamPtr &grad_param, const VariableNodePtr &VariableNode);
void BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint);
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list);
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);

View File

@ -34,10 +34,12 @@ namespace mindspore {
/* namespace to support opt */
namespace opt {
namespace {
const std::map<std::string, std::vector<std::string>> op2attrs = {{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}}};
const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimBroadcastTo->name(), {kAttrShape}},
{prim::kPrimReduceMax->name(), {kAttrKeepDims}},
{prim::kPrimReduceMin->name(), {kAttrKeepDims}},
{prim::kPrimReduceSum->name(), {kAttrKeepDims}},
{prim::kPrimMatMul->name(), {kTransposeA, kTransposeB}}};
}
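
Editor's note: the practical effect of the new MatMul entry is that ConvertPrimToPrimPy (partially shown below) now also carries the transpose flags onto the generated PrimitivePy. A hedged sketch of the attribute collection for one MatMul primitive; the loop body after the HasAttr check is an assumption, since the diff truncates it:

```cpp
parallel::OperatorAttrs attrs;
for (const auto &attr : op2attrs.at(prim::kPrimMatMul->name())) {  // {"transpose_a", "transpose_b"}
  if (primitive->HasAttr(attr)) {
    (void)attrs.emplace_back(std::make_pair(attr, primitive->GetAttr(attr)));
  }
}
```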
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
@ -53,7 +55,7 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
continue;
}
parallel::OperatorAttrs attrs;
auto iter = op2attrs.find(primitive->name());
const auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {

View File

@ -29,40 +29,32 @@ using mindspore::tensor::Tensor;
namespace mindspore {
namespace parallel {
std::string GetOpPythonPath(const OperatorName &op_name) {
// almost all ops are defined in two main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
const std::string grad_ops_module = GRAD_OP_PATH;
const std::string functional_op_module = FUNCTIONAL_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
py::module grad_mod = py::module::import(common::SafeCStr(grad_ops_module));
py::module functional_mod = py::module::import(common::SafeCStr(functional_op_module));
const char *GetOpPythonPath(const char *op_name) {
static py::module inner_mod = py::module::import(INNER_OP_PATH);
if (py::hasattr(inner_mod, op_name)) {
return INNER_OP_PATH;
}
if (py::hasattr(inner_mod, common::SafeCStr(op_name))) {
return inner_ops_module;
static py::module mod = py::module::import(OP_PATH);
if (py::hasattr(mod, op_name)) {
return OP_PATH;
}
if (py::hasattr(mod, common::SafeCStr(op_name))) {
return ops_module;
static py::module grad_mod = py::module::import(GRAD_OP_PATH);
if (py::hasattr(grad_mod, op_name)) {
return GRAD_OP_PATH;
}
if (py::hasattr(grad_mod, common::SafeCStr(op_name))) {
return grad_ops_module;
static py::module functional_mod = py::module::import(FUNCTIONAL_OP_PATH);
if (!py::hasattr(functional_mod, op_name)) {
MS_LOG(EXCEPTION) << OP_PATH << " and " << INNER_OP_PATH << " and " << GRAD_OP_PATH << " and " << FUNCTIONAL_OP_PATH
<< " don't have op:" << op_name;
}
if (!py::hasattr(functional_mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " and " << grad_ops_module << " and "
<< functional_op_module << " don't have op:" << op_name;
}
return functional_op_module;
return FUNCTIONAL_OP_PATH;
}
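
Editor's note: the rewrite replaces four eager py::module::import calls per lookup with lazily initialized statics and returns the module path as a const char*. A hedged usage sketch — "MatMul" as the op name is illustrative, and the pybind11 calls follow the pattern used inside the function itself:

```cpp
const char *op_path = GetOpPythonPath("MatMul");  // e.g. returns OP_PATH
py::module mod = py::module::import(op_path);     // import the module that defines the op
py::object op_class = mod.attr("MatMul");         // fetch the operator class for instantiation
```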
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
std::string op_path = GetOpPythonPath(op_name);
py::module mod = py::module::import(common::SafeCStr(op_path));
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
return nullptr;
}
const auto op_path = GetOpPythonPath(op_name.c_str());
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const Attr &attr) { return ValueToPyData(attr.second); });

View File

@ -31,7 +31,7 @@ namespace mindspore {
namespace parallel {
const char USING_HASH_NAME[] = "USING_HASH_NAME";
// Get the operator's path where the operator has been defined
std::string GetOpPythonPath(const OperatorName &op_name);
const char *GetOpPythonPath(const char *op_name);
// Init python operator Instance
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);

View File

@ -498,6 +498,8 @@ constexpr auto kAttrReshapeType = "reshape_type";
constexpr auto kAttrAxis = "axis";
constexpr auto kAttrAxes = "axes";
constexpr auto kAttrKeepDims = "keep_dims";
constexpr auto kTransposeA = "transpose_a";
constexpr auto kTransposeB = "transpose_b";
constexpr auto kAttrSkipMode = "skip_mode";
constexpr auto kAttrShapeGamma = "shape_gamma";
constexpr auto kAttrPerm = "perm";
@ -718,6 +720,7 @@ constexpr auto kFlagNeedRenormalize = "need_renormalize";
constexpr auto kFlagEnableZeroCopyInGraph = "enable_zero_copy_in_graph";
constexpr auto kFlagUseDynamicShapeProcess = "use_dynamic_shape_process";
constexpr auto kFlagEnableRunGraphBySingleOp = "enable_run_graph_by_single_op";
constexpr auto kPrimCPrimPyMixed = "primc_primpy_mixed";
// TODO(dsj): for ms_function running in graph_mode. Should be deleted later
constexpr auto kAttrMSFunction = "ms_function_graph";

View File

@ -198,6 +198,8 @@ FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource) {
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.depend_value_elim_,
irpass.switch_simplify_,
irpass.ad_related_special_op_eliminate_,
});
OptPassGroupMap map({{"ad_final_opt", bg_final_opt}});
if (pynative::PyNativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
@ -250,6 +252,7 @@ FuncGraphPtr OptGradGraphPass(const ResourcePtr &resource) {
WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
func_graph = bprop_graph_final_opt->step(func_graph, true);
};
func_graph = LiftingClone(func_graph);
// Validate(func_graph);
return func_graph;
}

View File

@ -88,12 +88,17 @@ struct InputArgsInfo {
bool has_custom_bprop;
size_t input_size;
std::string obj_id;
bool has_sens{false};
bool use_dynamic_shape_process = false;
bool grad_is_running{false};
bool use_dynamic_shape_process{false};
PrimitivePyPtr custom_bprp_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;
std::string already_run_cell_id;
std::string input_args_id;
// Cell unique id, cell_id + cell_order;
std::string obj_order_id;
size_t custom_bprop_cell_count = 0;
size_t grad_order = 0;
std::vector<std::string> input_arg_id_vec;

View File

@ -161,9 +161,9 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
MS_LOG(DEBUG) << "RunOp name: " << op_run_info->base_op_run_info.op_name;
// 1.Set cast for inputs
SetCastForInputs(op_run_info);
// 2. Infer output abstract
// 2.Infer output abstract
InferOutputAbstract(op_run_info);
// 3. Run op with selected backend
// 3.Run op with selected backend
if (!op_run_info->output_get_by_infer_value) {
GetOutput(op_run_info);
}
@ -171,7 +171,6 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
MS_LOG(DEBUG) << "Grad flag is false";
return;
}
// Set forward output flag for releasing memory.
// Because the tensor address may change, it should be set in the main thread to ensure consistency.
PyNativeAlgo::Common::SetForwardOutputFlag(op_run_info->out_value);
@ -181,7 +180,10 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
return;
}
// 4. Do op grad and record op info
if (!is_ms_function_compiling_) {
// If the cell has a custom bprop, there is no need to do op grad; otherwise, do it.
// While an ms_function is being compiled, the real run op must not be recorded.
bool is_custom_bprop = op_run_info->custom_bprop_cell_count <= 0;
if (!is_ms_function_compiling_ && is_custom_bprop) {
grad()->ProcessOpGradInfo(op_run_info);
}
}

File diff suppressed because it is too large

View File

@ -82,6 +82,10 @@ class GradExecutor {
inline void set_use_dynamic_shape_process(bool use_dynamic_shape_process) {
use_dynamic_shape_process_ = use_dynamic_shape_process;
}
inline InputArgsInfoPtr top_input_args_info() const {
MS_EXCEPTION_IF_NULL(top_input_args_info_);
return top_input_args_info_;
}
inline bool need_renormalize() const { return need_renormalize_; }
inline void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
@ -101,28 +105,24 @@ class GradExecutor {
py::object CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj, const py::object &grad_hash_id,
const py::args &args);
TopCellInfoPtr GetTopCell(const std::string &already_run_cell_id);
void GetPreRunTopCell(const py::object &obj);
void ProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void AsyncProcessOpGradInfo(const FrontendOpRunInfoPtr &op_run_info) const;
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
AnfNodePtr GetInput(const ValuePtr &v, const string &obj_id) const;
void AsyncEndGraphImpl(const InputArgsInfoPtr input_args_info);
AnfNodePtr GetParamInput(const ValuePtr &v, const std::string &id) const;
void UpdateForwardTensorInfoInBpropGraph(const FrontendOpRunInfoPtr &op_run_info) const;
void UpdatePreTensorInfo(const tensor::TensorPtr &new_tensor,
const std::vector<tensor::TensorPtr> &pre_tensors) const;
void ClearRes();
void WorkerJoin() { async_executor_->WorkerJoin(); }
void CheckGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node = false,
void CheckGraphDynamic(const CNodePtr &cnode, const size_t node_idx, bool is_ms_function_node = false,
const std::string &graph_phase = "") const;
private:
ForwardExecutorPtr forward() const;
inline FuncGraphPtr curr_g() const { return top_cell()->fg(); }
inline void PushHighOrderGraphStack(const TopCellInfoPtr &top_cell) { high_order_stack_.push(top_cell); }
std::string GetCurCellOrder() const;
void SetGradOrder(const std::string &cell_id);
void SetGradOrder(const std::string &obj_id);
void SaveOutputNodeMap(const std::string &obj_id, const FrontendOpRunInfoPtr &op_run_info,
const CNodePtr &cnode) const;
void DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode, const ValuePtr &op_out) const;
@ -158,7 +158,6 @@ class GradExecutor {
void HandleInputArgsForTopCell(const InputArgsInfoPtr &input_args_info, bool is_bprop_top) const;
void InitResourceAndDfBuilder(const InputArgsInfoPtr &cell_info);
void MakeNewTopGraph(const InputArgsInfoPtr &input_args_info);
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph) const;
// Manage resource when run grad process.
bool IsBpropGraph(const std::string &cell_id) const;
@ -166,6 +165,11 @@ class GradExecutor {
InputArgsInfoPtr GetInputArgsInfo(const py::object &obj, const py::args &args);
void NewGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncNewGraphImpl(const InputArgsInfoPtr &input_args_info);
void EndGraphInner(const py::object &obj, const py::object &out, const py::args &args);
void UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj, const py::object &out,
const py::args &args);
void EndGraphImpl(const InputArgsInfoPtr &input_args_info);
void AsyncEndGraphImpl(const InputArgsInfoPtr &input_args_info);
void SetForwardLastNodeInfo(const ValuePtr &v, const std::string &obj_id) const;
void GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info);
@ -189,14 +193,10 @@ class GradExecutor {
AnfNodePtr CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const;
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
void SaveDynamicDetectNodeInfoInFirstTime(const CNodePtr &cnode, size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsGraphDynamic(const CNodePtr &cnode, const size_t &node_idx, bool is_ms_function_node,
bool IsGraphDynamic(const CNodePtr &cnode, size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const;
bool IsCnodeInputsDynamic(const DynamicDetectNodeInfoPtr &old_node_info,
const std::vector<AnfNodePtr> &new_anf_inputs) const;
bool IsDynamicDetectNodeInfoChange(const DynamicDetectNodeInfoPtr &old_node_info, const CNodePtr &new_cnode,
bool is_ms_function_node, const std::string &graph_phase) const;
bool grad_flag_{false};
bool grad_is_running_{false};
bool need_renormalize_{false};
@ -204,16 +204,11 @@ class GradExecutor {
mutable bool use_dynamic_shape_process_{false};
mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
int custom_bprop_cell_count_{0};
// Used in sub thread
size_t cell_order_{0};
std::string cur_cell_id_{""};
size_t obj_order_{0};
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
size_t grad_order_{0};
std::string grad_operation_;
TopCellInfoPtr top_cell_{nullptr};
TopCellInfoPtr pre_top_cell_{nullptr};
InputArgsInfoPtr top_input_args_info_{nullptr};
// Records every cell info for share, regardless of whether need construct grad graph
std::stack<InputArgsInfoPtr> input_args_info_stack_;

View File

@ -75,8 +75,8 @@ size_t GetOutputTensorNumForTuple(const CNodePtr &make_tuple) {
} // namespace
void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
const std::vector<tensor::TensorPtr> &total_output_tensors,
const FuncGraphPtr &grad_graph) const {
const std::vector<tensor::TensorPtr> &total_output_tensors, const FuncGraphPtr &grad_graph,
bool is_dynamic_shape) const {
MS_EXCEPTION_IF_NULL(added_make_tuple);
size_t index = 0;
for (size_t i = 1; i < added_make_tuple->size(); ++i) {
@ -110,6 +110,17 @@ void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
} else {
MS_LOG(EXCEPTION) << "The output value of forward cnode is empty, forward cnode info: " << cnode->ToString();
}
if (is_dynamic_shape) {
if (output_num == 1) {
output_vnode->set_abstract(new_values[0]->ToAbstract()->Broaden());
} else {
AbstractBasePtrList abs_list;
for (size_t j = 0; j < output_num; ++j) {
(void)abs_list.emplace_back(new_values[j]->ToAbstract()->Broaden());
}
output_vnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
}
}
MS_LOG(DEBUG) << "New output value node: " << output_vnode->ToString();
}
// Save op info with new tensors for current running ms_function func graph.
@ -119,9 +130,8 @@ void MsFunction::RunReplace(const CNodePtr &added_make_tuple,
}
}
void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const {
MS_EXCEPTION_IF_NULL(top_cell);
void MsFunction::ReplaceNewTensorsInGradGraph(const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const {
MS_EXCEPTION_IF_NULL(ms_func_graph);
// Get added forward nodes.
auto merge_node = ms_func_graph->output();
@ -150,7 +160,7 @@ void MsFunction::ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, co
// The forward node in ms_function graph is created during compilation and is a
// placeholder(mindspore/ccsrc/frontend/optimizer/ad/pynative_dfunctor.cc).After running ms_function, need to update
// to real value.
RunReplace(added_make_tuple, total_output_tensors, grad_graph);
RunReplace(added_make_tuple, total_output_tensors, grad_graph, is_dynamic_shape);
}
void MsFunction::UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
@ -190,11 +200,10 @@ void MsFunction::GetInputArgsNode(const FrontendOpRunInfoPtr &op_run_info, AnfNo
for (size_t i = 0; i < op_run_info->input_size; ++i) {
const auto &input_i_value = op_run_info->input_value[i];
const auto &id = PyNativeAlgo::Common::GetIdByValue(input_i_value);
MS_LOG(DEBUG) << "The input " << i << " id " << id
<< " value of ms_function graph is: " << input_i_value->ToString();
const auto &input_i_node = grad_executor->GetInput(input_i_value, id);
MS_EXCEPTION_IF_NULL(input_i_node);
MS_LOG(DEBUG) << "The input " << i << " node of ms_function graph is: " << input_i_node->DebugString();
MS_LOG(DEBUG) << "The input " << i << " id " << id << " value is: " << input_i_value->ToString()
<< ", node is: " << input_i_node->DebugString();
(void)input_nodes->emplace_back(input_i_node);
}
}
@ -230,11 +239,11 @@ void MsFunction::GetWeightsNode(const FrontendOpRunInfoPtr &op_run_info, const G
} else {
top_cell->fg()->add_parameter(param);
param->debug_info()->set_name(param->name());
top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
}
(void)new_params.emplace_back(param);
(void)input_nodes->emplace_back(param);
(void)op_run_info->input_value.emplace_back(tensor_value);
top_cell->SetParamNodeMapInGraphInfoMap(tensor_value->id(), param, true);
MS_LOG(DEBUG) << "Top graph set free parameter " << param->DebugString() << ". Its default value is "
<< tensor_value->ToString() << ". Its name is: " << param->name();
}
@ -254,7 +263,6 @@ void MsFunction::MakeCNodeForMsFunction(const FrontendOpRunInfoPtr &op_run_info,
// Make a CNode which includes ms_function fprop graph and inputs node
MS_EXCEPTION_IF_NULL(ms_function_cnode);
*ms_function_cnode = grad_executor->top_cell()->fg()->NewCNode(input_nodes);
MS_LOG(DEBUG) << "Make ms function forward CNode: " << (*ms_function_cnode)->DebugString();
}
@ -291,7 +299,7 @@ void MsFunction::AsyncKPynativeWithFProp(const GradExecutor *grad_executor,
const ad::AutoGradCellImplPtr &auto_grad_cell_ptr,
const ad::GradParamPtr &grad_param) const {
MS_EXCEPTION_IF_NULL(grad_executor);
const auto fn = [this, grad_param, auto_grad_cell_ptr]() {
const auto fn = [grad_param, auto_grad_cell_ptr]() {
MS_EXCEPTION_IF_NULL(auto_grad_cell_ptr);
if (!auto_grad_cell_ptr->KPynativeWithFProp(grad_param)) {
MS_LOG(EXCEPTION) << "Failed to make adjoint for ms_function cnode";
@ -317,6 +325,7 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
MS_EXCEPTION_IF_NULL(op_run_info);
MS_EXCEPTION_IF_NULL(grad_executor);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
// Step 1: Update actual output tensors used in grad graph.
MS_EXCEPTION_IF_NULL(op_run_info->out_value);
MS_LOG(DEBUG) << "ms_function actual output value: " << op_run_info->out_value->ToString();
@ -328,8 +337,12 @@ void MsFunction::GradMsFunctionInner(const FrontendOpRunInfoPtr &op_run_info, co
grad_executor->top_cell()->op_info_with_ms_func_forward_tensors().end()) {
UpdateMsFunctionForwardTensors(grad_executor, op_run_info->op_info, added_out_v);
}
ReplaceNewTensorsInGradGraph(grad_executor->top_cell(), added_out_v, ms_func_graph, grad_graph);
bool is_dynamic_shape = common::AnfAlgo::IsDynamicShape(ms_func_graph->output());
if (is_dynamic_shape) {
const_cast<GradExecutor *>(grad_executor)->set_use_dynamic_shape_process(true);
MS_LOG(DEBUG) << "Ms function is dynamic shape";
}
ReplaceNewTensorsInGradGraph(added_out_v, ms_func_graph, grad_graph, is_dynamic_shape);
// Clone new ms_function func graph and grad graph.
auto new_ms_func_graph = BasicClone(ms_func_graph);

View File

@ -49,9 +49,9 @@ class MsFunction {
const ad::GradParamPtr &grad_param) const;
// Update device address of value node in grad graph by forward tensors.
void RunReplace(const CNodePtr &added_make_tuple, const std::vector<tensor::TensorPtr> &total_output_tensors,
const FuncGraphPtr &grad_graph) const;
void ReplaceNewTensorsInGradGraph(const TopCellInfoPtr &top_cell, const ValuePtr &added_out,
const FuncGraphPtr &ms_func_graph, const FuncGraphPtr &grad_graph) const;
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const;
void ReplaceNewTensorsInGradGraph(const ValuePtr &added_out, const FuncGraphPtr &ms_func_graph,
const FuncGraphPtr &grad_graph, bool is_dynamic_shape) const;
void UpdateMsFunctionForwardTensors(const GradExecutor *grad_executor, const string &op_info,
const ValuePtr &new_forward_value) const;
// Make CNode for ms_function forward graph.

View File

@ -91,6 +91,12 @@ void TopCellInfo::GetOpInfo(const FrontendOpRunInfoPtr &op_run_info) {
++op_index_;
}
void TopCellInfo::UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile) {
need_compile_graph_ = need_compile_graph;
forward_already_run_ = forward_already_run;
vm_compile_ = vm_compile;
}
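
Editor's note: this appears to replace GradExecutor::UpdateTopCellInfo (whose declaration is removed in the header diff below) and folds the new vm_compile flag into the top cell. A hedged call-site sketch; the actual caller is not visible in this diff:

```cpp
top_cell()->UpdateTopCellInfo(/*forward_already_run=*/true,
                              /*need_compile_graph=*/true,
                              /*vm_compile=*/true);
```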
void TopCellInfo::ClearDeviceMemory() const {
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
auto ms_context = MsContext::GetInstance();
@ -127,6 +133,23 @@ void TopCellInfo::ClearDeviceMemory() const {
}
}
void TopCellInfo::Clear() {
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
hook_changed_ = false;
is_init_kpynative_ = false;
need_compile_graph_ = false;
forward_already_run_ = false;
vm_compile_ = false;
op_index_ = 0;
resource_ = nullptr;
fg_ = nullptr;
graph_info_map_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
cnode_hash_with_op_index_.clear();
}
void TopCellInfo::DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id) {
auto &graph_info = graph_info_map().at(g);
MS_EXCEPTION_IF_NULL(graph_info);
@ -188,22 +211,6 @@ void TopCellInfo::SetNestedMultipleOutputToGraphInfoMap(const string &id, const
}
}
void TopCellInfo::Clear() {
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
hook_changed_ = false;
is_init_kpynative_ = false;
need_compile_graph_ = false;
forward_already_run_ = false;
op_index_ = 0;
resource_ = nullptr;
fg_ = nullptr;
graph_info_map_.clear();
op_info_with_tensor_id_.clear();
tensor_id_with_tensor_object_.clear();
op_info_with_ms_func_forward_tensors_.clear();
cnode_hash_with_op_index_.clear();
}
void TopCellInfo::SetUnpackOutputToGraphInfoMap(const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) const {
auto &graph_info = graph_info_map().at(fg());

View File

@ -58,11 +58,11 @@ using GraphInfoPtr = std::shared_ptr<GraphInfo>;
class TopCellInfo {
public:
~TopCellInfo() = default;
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string c_cell_id, std::string cellid,
TopCellInfo(bool is_high_order_top_cell, size_t grad_order, std::string obj_id_with_grad_order, std::string cellid,
std::string already_run_cell_id, pipeline::ResourcePtr r, FuncGraphPtr fg)
: is_high_order_top_cell_(is_high_order_top_cell),
grad_order_(grad_order),
c_cell_id_(std::move(c_cell_id)),
obj_id_with_grad_order_(std::move(obj_id_with_grad_order)),
cell_id_(std::move(cellid)),
already_run_cell_id_(std::move(already_run_cell_id)),
resource_(std::move(r)),
@ -79,11 +79,12 @@ class TopCellInfo {
inline void ClearCellHookOp() { cell_backward_hook_op_.clear(); }
inline bool forward_already_run() const { return forward_already_run_; }
inline void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline bool vm_compile() const { return vm_compile_; }
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; }
inline bool need_compile_graph() const { return need_compile_graph_; }
inline void set_need_compile_graph(bool need_compile_graph) { need_compile_graph_ = need_compile_graph; }
inline pipeline::ResourcePtr resource() const { return resource_; }
inline FuncGraphPtr fg() const {
MS_EXCEPTION_IF_NULL(fg_);
@ -91,7 +92,7 @@ class TopCellInfo {
}
inline void set_fg(const FuncGraphPtr &fg) { fg_ = fg; }
inline const std::string &cell_id() const { return cell_id_; }
inline const std::string &c_cell_id() const { return c_cell_id_; }
inline const std::string &obj_id_with_grad_order() const { return obj_id_with_grad_order_; }
inline const std::string &already_run_cell_id() const { return already_run_cell_id_; }
inline void set_input_args_id(const std::string &input_args_id) { input_args_id_ = input_args_id; }
inline const std::string &input_args_id() const { return input_args_id_; }
@ -124,10 +125,11 @@ class TopCellInfo {
inline void set_cnode_hash_with_op_index(const size_t &node_hash, const size_t &op_index) {
cnode_hash_with_op_index_[node_hash] = op_index;
}
inline size_t get_op_index_by_cnode_hash(const size_t &node_hash) {
auto iter = cnode_hash_with_op_index_.find(node_hash);
inline size_t get_op_index_by_cnode_hash(const size_t node_hash, const size_t node_idx) const {
const auto iter = cnode_hash_with_op_index_.find(node_hash);
if (iter == cnode_hash_with_op_index_.end()) {
MS_LOG(EXCEPTION) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
MS_LOG(DEBUG) << "hash:" << node_hash << " is not found in cnode_hash_with_op_index_";
return node_idx;
}
return iter->second;
}
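
The change above makes get_op_index_by_cnode_hash tolerate a cache miss: instead of raising an exception, it logs at DEBUG level and falls back to the index supplied by the caller. A minimal standalone sketch of that lookup-with-fallback pattern follows; the map and function names are illustrative only, not the real TopCellInfo members.

#include <cstddef>
#include <iostream>
#include <unordered_map>

// Illustrative stand-in for cnode_hash_with_op_index_.
std::unordered_map<size_t, size_t> hash_to_op_index;

// On a miss, return the caller-supplied index instead of throwing.
size_t GetOpIndexByHash(size_t node_hash, size_t node_idx) {
  const auto iter = hash_to_op_index.find(node_hash);
  if (iter == hash_to_op_index.end()) {
    std::cout << "hash " << node_hash << " not found, falling back to " << node_idx << "\n";
    return node_idx;
  }
  return iter->second;
}

int main() {
  hash_to_op_index[42] = 7;
  std::cout << GetOpIndexByHash(42, 0) << "\n";  // recorded: prints 7
  std::cout << GetOpIndexByHash(99, 3) << "\n";  // miss: prints 3
  return 0;
}

The trade-off is that execution continues with the caller's index instead of aborting, and the miss is only visible in the debug log.
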
@ -137,7 +139,8 @@ class TopCellInfo {
void DeleteParamNodeInfo(const FuncGraphPtr &g, const std::string &id);
void SetParamNodeMapInGraphInfoMap(const std::string &id, const ParameterPtr &param, bool is_weight = false) const;
void SetNodeMapInGraphInfoMap(const std::string &id, const AnfNodePtr &node, int64_t index = -1,
bool save_flag = true) const;
bool need_save_sub_id = true) const;
void UpdateTopCellInfo(bool forward_already_run, bool need_compile_graph, bool vm_compile);
void ClearDeviceMemory() const;
private:
@ -150,12 +153,13 @@ class TopCellInfo {
bool hook_changed_{false};
bool is_init_kpynative_{false};
bool forward_already_run_{false};
bool need_compile_graph_{false};
bool vm_compile_{false};
bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false};
bool need_compile_graph_{false};
size_t op_index_{0};
size_t grad_order_{0};
std::string c_cell_id_;
std::string obj_id_with_grad_order_;
std::string cell_id_;
std::string already_run_cell_id_;
std::string input_args_id_;

View File

@ -213,6 +213,7 @@ void PyNativeExecutor::SetMsFunctionCompileStatus(bool is_compiling) const {
void PyNativeExecutor::SetDynamicInput(const py::object &cell, const py::args &args) const {
grad_executor()->set_use_dynamic_shape_process(true);
MS_LOG(DEBUG) << "Set dynamic shape by set inputs";
}
void RegPyNativeExecutor(const py::module *m) {

View File

@ -29,6 +29,18 @@
namespace mindspore {
namespace pynative {
namespace PyNativeAlgo {
namespace {
void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
// Clone a new prim
MS_EXCEPTION_IF_NULL(op_run_info);
op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
}
}
} // namespace
std::string Common::GetIdByValue(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(v);
if (v->isa<tensor::Tensor>()) {
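
The new ClonePrim helper centralizes a clone-before-mutate step: before per-launch attribute rewriting, the shared primitive is replaced with a private copy so other holders of the original are unaffected. Below is a simplified, framework-free sketch of that pattern; Primitive and OpRunInfo here are stand-ins, not the real PrimitivePy and FrontendOpRunInfo types.

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for a primitive whose attributes may be rewritten per launch.
struct Primitive {
  std::string name;
  std::map<std::string, int64_t> attrs;
};

struct OpRunInfo {
  std::shared_ptr<Primitive> op_prim;
};

// Replace the shared primitive with a private copy before mutating its attributes.
void ClonePrim(OpRunInfo *op_run_info) {
  op_run_info->op_prim = std::make_shared<Primitive>(*op_run_info->op_prim);
}

int main() {
  auto shared_prim = std::make_shared<Primitive>(Primitive{"AddN", {{"n", 2}}});
  OpRunInfo run_info{shared_prim};
  ClonePrim(&run_info);
  run_info.op_prim->attrs["n"] = 3;  // only the private copy changes
  std::cout << shared_prim->attrs["n"] << " vs " << run_info.op_prim->attrs["n"] << "\n";  // 2 vs 3
  return 0;
}
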
@ -101,6 +113,26 @@ bool Common::IsTensor(const ValuePtr &v, bool include_sequence) {
return v->isa<tensor::Tensor>() || v->isa<tensor::MetaSparseTensor>();
}
ValuePtr Common::FilterSensValues(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<tensor::Tensor>() || value->isa<tensor::COOTensor>() || value->isa<tensor::CSRTensor>()) {
return value;
} else if (value->isa<ValueSequence>()) {
std::vector<ValuePtr> value_list;
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
for (auto &filter_value : value_seq->value()) {
if (FilterSensValues(filter_value) != nullptr) {
(void)value_list.emplace_back(filter_value);
}
}
return std::make_shared<ValueTuple>(value_list);
} else {
MS_LOG(DEBUG) << "Value type: " << value->ToString();
return nullptr;
}
}
tensor::TensorPtr Common::GetTensorFromParam(const AnfNodePtr &param_node) {
MS_EXCEPTION_IF_NULL(param_node);
auto param = param_node->cast<ParameterPtr>();
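
Common::FilterSensValues above keeps tensor values, recurses into value sequences, and drops everything else, so only tensor-like entries survive in the filtered sens value. A hedged, self-contained sketch of the same recursion over a toy value tree follows; std::variant stands in for MindSpore's Value hierarchy.

#include <iostream>
#include <memory>
#include <variant>
#include <vector>

// Minimal stand-ins: a "value" is either a tensor-like leaf, a nested sequence, or something else.
struct Tensor { int id; };
struct Value;
using ValuePtr = std::shared_ptr<Value>;
struct Value {
  std::variant<Tensor, std::vector<ValuePtr>, std::monostate> data;  // monostate ~ non-tensor value
};

// Keep tensor leaves, recurse into sequences, drop everything else.
ValuePtr FilterSensValues(const ValuePtr &value) {
  if (std::holds_alternative<Tensor>(value->data)) {
    return value;
  }
  if (auto seq = std::get_if<std::vector<ValuePtr>>(&value->data)) {
    std::vector<ValuePtr> kept;
    for (const auto &item : *seq) {
      if (FilterSensValues(item) != nullptr) {
        kept.push_back(item);
      }
    }
    return std::make_shared<Value>(Value{kept});
  }
  return nullptr;  // not a tensor or sequence, drop it
}

int main() {
  auto t = std::make_shared<Value>(Value{Tensor{1}});
  auto other = std::make_shared<Value>(Value{std::monostate{}});
  auto tuple = std::make_shared<Value>(Value{std::vector<ValuePtr>{t, other}});
  auto filtered = FilterSensValues(tuple);
  std::cout << std::get<std::vector<ValuePtr>>(filtered->data).size() << "\n";  // 1: only the tensor survives
  return 0;
}
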
@ -523,12 +555,8 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
bool need_convert_input_to_attr = NeedConvertConstInputToAttr(op_run_info, device_target, &input_to_attr);
MS_LOG(DEBUG) << "Need convert input to addr " << need_convert_input_to_attr;
if (need_convert_input_to_attr) {
// Clone a new prim
op_run_info->op_prim = std::make_shared<PrimitivePy>(*(op_run_info->op_prim));
MS_EXCEPTION_IF_NULL(op_run_info->op_prim->adapter());
if (op_run_info->op_prim->adapter()->attached_primitive() == nullptr) {
op_run_info->op_prim->adapter()->set_attached_primitive(op_run_info->op_prim);
}
// The prim's attrs may be modified below, so work on a cloned prim
ClonePrim(op_run_info);
}
const auto &op_prim = op_run_info->op_prim;
@ -544,10 +572,17 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
// Mark tensors: common tensor data: 0, weight param: 1, value node (float_, int_): 2
ConvertValueToTensor(op_run_info, input_object, index, op_prim);
// -1 indicates input_object is not a dynInput
if (op_prim->HasAttr(kAttrDynInputSizes) && !input_object->isa<ValueSequence>()) {
auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
if (op_prim->HasAttr(kAttrDynInputSizes)) {
if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE)) {
// For ops like AddN, the prim is defined in Python but the number of inputs can change, so the value of
// kAttrDynInputSizes changes too. In async mode, the grad of this op may not have finished yet, so clone a new prim.
ClonePrim(op_run_info);
}
if (!input_object->isa<ValueSequence>()) {
auto dyn_v = GetValue<const std::vector<int64_t>>(op_prim->GetAttr(kAttrDynInputSizes));
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
}
}
}
op_prim->EndRecordAddAttr();
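
Two things happen in the kAttrDynInputSizes branch above: the prim is cloned first when PyNative runs asynchronously, presumably so a grad task that has not finished yet is unaffected by the attribute rewrite, and a -1 is appended for any input that is not a ValueSequence. A small illustrative sketch of what that -1 convention encodes follows; the function name and the plain vector are assumptions, not the real attribute plumbing.

#include <cstdint>
#include <iostream>
#include <vector>

// -1 marks an input that is not part of a dynamic (variadic) input group.
void RecordInput(std::vector<int64_t> *dyn_input_sizes, bool input_is_sequence, int64_t sequence_len) {
  if (input_is_sequence) {
    dyn_input_sizes->push_back(sequence_len);  // e.g. AddN with 3 operands records 3
  } else {
    dyn_input_sizes->push_back(-1);
  }
}

int main() {
  std::vector<int64_t> dyn_input_sizes;
  RecordInput(&dyn_input_sizes, true, 3);   // a tuple of 3 tensors
  RecordInput(&dyn_input_sizes, false, 0);  // a single tensor
  for (auto v : dyn_input_sizes) std::cout << v << " ";  // prints: 3 -1
  std::cout << "\n";
  return 0;
}
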

View File

@ -33,6 +33,7 @@ struct Common {
static std::string GetIdByValue(const ValuePtr &v);
static bool ValueHasDynamicShape(const ValuePtr &value);
static bool IsTensor(const ValuePtr &v, bool include_sequence = false);
static ValuePtr FilterSensValues(const ValuePtr &value);
static tensor::TensorPtr GetTensorFromParam(const AnfNodePtr &param_node);
static void SetForwardOutputFlag(const ValuePtr &v);
static void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);

View File

@ -110,6 +110,7 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
device::DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(op_run_info);
MS_EXCEPTION_IF_NULL(device_context);
py::gil_scoped_acquire acquire_gil;
auto graph_info = op_run_info->base_op_run_info.graph_info;
auto iter = op_compiler_infos_.find(graph_info);
// Check if the graph cache exists.
@ -160,7 +161,6 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
auto op_compiler_info =
std::make_shared<OpCompilerInfo>(graph_info, graph->graph_id(), graph, outputs_with_index, device_context, false);
py::gil_scoped_acquire acquire_gil;
op_compiler_infos_[graph_info] = op_compiler_info;
return op_compiler_info;
}
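
The GIL is now acquired at the top of OpCompiler::Compile rather than just before the cache insertion, so the whole lookup-and-build sequence runs with the GIL held. As a loose analogy only (the GIL is about Python interop during graph construction, not cache locking), here is a plain-mutex sketch of protecting the full check-then-build-then-insert sequence; the cache and types are invented for illustration.

#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct CompiledOp { std::string graph_info; };

std::mutex cache_mutex;
std::unordered_map<std::string, std::shared_ptr<CompiledOp>> cache;

std::shared_ptr<CompiledOp> Compile(const std::string &graph_info) {
  // Hold the lock for the whole check-then-build-then-insert sequence,
  // mirroring how the acquire now covers all of OpCompiler::Compile.
  std::lock_guard<std::mutex> lock(cache_mutex);
  auto iter = cache.find(graph_info);
  if (iter != cache.end()) {
    return iter->second;
  }
  auto compiled = std::make_shared<CompiledOp>(CompiledOp{graph_info});  // stand-in for graph build
  cache[graph_info] = compiled;
  return compiled;
}

int main() {
  auto a = Compile("MatMul_fp32");
  auto b = Compile("MatMul_fp32");
  std::cout << (a == b ? "cache hit" : "rebuilt") << "\n";  // cache hit
  return 0;
}
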

View File

@ -225,3 +225,9 @@ def bprop_scalar_calc(x, y, out, dout):
def bprop_scalar_not(x, out, dout):
"""Backpropagator for primitive `bool_not` and `string_not`."""
return (C.zeros_like(x),)
@bprops.register("TensorMove")
def bprop_tensor_move(x, out, dout):
"""Backpropagator for primitive `TensorMove`."""
return (dout,)

View File

@ -1034,5 +1034,6 @@ class RandomShuffle(Primitive):
def __init__(self, seed=0, seed2=0):
"""Initialize RandomShuffle"""
self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
self.add_prim_attr("side_effect_hidden", True)
Validator.check_non_negative_int(seed, "seed", self.name)
Validator.check_non_negative_int(seed2, "seed2", self.name)

View File

@ -213,9 +213,9 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
/// Description:
/// Expectation: the python path is right
TEST_F(TestStepParallel, GetPythonPath1) {
OperatorName operator_name = "AllReduce";
const char *operator_name = "AllReduce";
const std::string expect = "mindspore.ops.operations";
auto temp = parallel::GetOpPythonPath(operator_name);
std::string temp = parallel::GetOpPythonPath(operator_name);
ASSERT_EQ(temp, expect);
}
@ -223,9 +223,9 @@ TEST_F(TestStepParallel, GetPythonPath1) {
/// Description:
/// Expectation: the python path is right
TEST_F(TestStepParallel, GetPythonPath2) {
OperatorName operator_name = "Add";
const char *operator_name = "Add";
const std::string expect = "mindspore.ops.operations";
auto temp = parallel::GetOpPythonPath(operator_name);
std::string temp = parallel::GetOpPythonPath(operator_name);
ASSERT_EQ(temp, expect);
}