Support isolated side-effect nodes in func graphs that return constant values.

Zhang Qinghua 2022-07-21 16:21:21 +08:00
parent 26c15ce6c2
commit f90dcc963d
28 changed files with 251 additions and 117 deletions
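
In user-facing terms, this change lets a `construct` call side-effect operators (e.g. `P.Assign`) whose results are unused and still return a constant or inferred-constant value, instead of failing analysis with a "Side Effect Invalid" exception (the `CheckSideEffectNodes` check removed below). A minimal sketch, adapted from the new tests at the end of this commit:

```python
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

y = Parameter(Tensor([1]), name="y")

class Demo(nn.Cell):
    def construct(self, x):
        # Both Assign results are discarded, so these CNodes are "isolated":
        # they stay reachable only through the Depend/StopGradient chain the
        # parser attaches before the constant return.
        P.Assign()(x, Tensor([0]))
        P.Assign()(y, Tensor([0]))
        return True

x = Parameter(Tensor([1]), name="x")
net = Demo()
print(net(x))  # True; the assignments to x and y are still executed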

View File

@ -490,6 +490,75 @@ bool CombineLikeGraphs(const ResourcePtr &resource) {
return true;
}
namespace {
bool IsSideEffectCNode(const AnfNodePtr &node) {
const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
if (primitive != nullptr) {
auto effect_info = GetPrimEffectInfo(primitive);
if (effect_info.memory || effect_info.io) {
MS_LOG(DEBUG) << "Side Effect Primitive CNode: " << node->DebugString();
return true;
}
}
return false;
}
bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph) {
const auto node = func_graph->output();
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
return false;
}
auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
auto sort_rhs_first =
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
if (!sort_rhs_first) {
// Return false if it's definitely not isolated Depend CNode.
return false;
}
// To check isolated nodes in {Depend -> StopGradient -> MakeTuple(...)}.
constexpr size_t stop_gradient_pos = 2;
auto stop_gradient_node = cnode->input(stop_gradient_pos);
auto stop_gradient_cnode = dyn_cast<CNode>(stop_gradient_node);
MS_EXCEPTION_IF_NULL(stop_gradient_cnode);
constexpr size_t isolated_node_pos = 1;
auto isolated_node = stop_gradient_cnode->input(isolated_node_pos);
if (IsPrimitiveCNode(isolated_node, prim::kPrimMakeTuple)) {
auto isolated_cnode = dyn_cast<CNode>(isolated_node);
MS_EXCEPTION_IF_NULL(isolated_cnode);
for (size_t i = 1; i < isolated_cnode->size(); ++i) {
if (IsSideEffectCNode(isolated_cnode->input(i))) {
MS_LOG(DEBUG) << "Multiple isolated side-effect node[" << i << "]: " << isolated_cnode->input(i)->DebugString();
return true;
}
}
} else {
if (IsSideEffectCNode(isolated_node)) {
MS_LOG(DEBUG) << "Single isolated side-effect node: " << isolated_node->DebugString();
return true;
}
}
return false;
}
void CheckIsolatedSideEffectNode(const FuncGraphPtr &func_graph) {
if (!HasIsolatedSideEffectNode(func_graph)) {
return;
}
auto new_return = func_graph->get_return();
new_return->set_has_isolated_side_effect_node(true);
func_graph->set_has_isolated_side_effect_node(true);
auto output_cnode = dyn_cast<CNode>(func_graph->output());
if (output_cnode != nullptr) {
output_cnode->set_has_isolated_side_effect_node(true);
}
MS_LOG(INFO) << "Set isolated side-effect node flag for " << func_graph->ToString();
}
} // namespace
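To make the matched shape concrete: for the `Demo` cell above, the parser (see `FunctionBlock::AttachIsolatedNodesBeforeReturn` in the next file) wraps the constant output in a `Depend` marked with `kAttrTopoSortRhsFirst`, whose second input is a `StopGradient` over the isolated nodes, tupled when there is more than one. A hedged pseudo-IR sketch (not the exact MindSpore dump format):

```python
# %0 = Assign(x, Tensor([0]))    # isolated side-effect CNode
# %1 = Assign(y, Tensor([0]))    # isolated side-effect CNode
# %2 = MakeTuple(%0, %1)         # multiple isolated nodes are tupled
# %3 = StopGradient(%2)
# %4 = Depend(True, %3)          # attr: topo_sort_rhs_first = true
# Return(%4)
#
# HasIsolatedSideEffectNode above walks exactly this chain:
# output -> Depend[topo_sort_rhs_first] -> StopGradient (input 2)
#        -> MakeTuple or a single node (input 1) -> IsSideEffectCNode per input.
```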
bool SymbolResolveAction(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
if (resource->manager() == nullptr) {
@ -500,12 +569,15 @@ bool SymbolResolveAction(const ResourcePtr &resource) {
MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
}
bool ret = parse::ResolveFuncGraph(func_graph, resource);
- // Remove unused nodes in cnode order list.
- if (func_graph) {
+ // Remove unused nodes in cnode order list,
+ // and check isolated side-effect nodes.
+ if (func_graph != nullptr) {
func_graph->EraseUnusedNodeInOrder();
CheckIsolatedSideEffectNode(func_graph);
for (auto fg : func_graph->func_graphs_used_total()) {
- if (fg) {
+ if (fg != nullptr) {
fg->EraseUnusedNodeInOrder();
CheckIsolatedSideEffectNode(fg);
}
}
}
@ -589,6 +661,9 @@ bool AbstractSpecializeAction(const ResourcePtr &resource) {
auto loaded_graph_ptr = GetLoadedGraph(resource);
// Abstract analyze
auto engine = resource->engine();
MS_EXCEPTION_IF_NULL(engine);
engine->set_check_isolated_side_effect(true);
AnalysisResult result = AbstractAnalyze(resource, resource->func_graph(), GetArgsAbs(resource));
// The top graph may be replaced by infer, update the top graph when the infer is done
@ -597,6 +672,7 @@ bool AbstractSpecializeAction(const ResourcePtr &resource) {
// Specialize
FuncGraphPtr new_fg = ProgramSpecialize(resource, result.context->func_graph(), result.context);
resource->set_func_graph(new_fg);
engine->set_check_isolated_side_effect(false);
// Remove unused nodes in cnode order list, this is prepared for auto-monad.
if (new_fg) {

View File

@ -646,7 +646,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
AnfNodePtr old_output = nullptr;
auto return_node = func_graph_->get_return();
- if (return_node) {
+ if (return_node != nullptr) {
const size_t return_input_size = 2;
if (return_node->inputs().size() < return_input_size) {
MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2";
@ -670,7 +670,8 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(recursive_level);
func_graph_->set_output(depend_node, true);
- if (return_node && return_node->debug_info()) {
+ // Update new return node's debug_info with old one.
+ if (return_node != nullptr && return_node->debug_info()) {
auto new_return = func_graph_->get_return();
new_return->set_debug_info(return_node->debug_info());
}

View File

@ -248,9 +248,7 @@ bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
return false;
}
auto cnode = dyn_cast<CNode>(node);
- if (cnode == nullptr) {
- return false;
- }
+ MS_EXCEPTION_IF_NULL(cnode);
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
auto sort_rhs_first =
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);

View File

@ -79,45 +79,6 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg)
return false;
}
- void BaseFuncGraphEvaluator::CollectSideEffectNodes(const AnfNodePtr &node,
- std::vector<AnfNodePtr> *side_effect_nodes) {
- const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
- if (primitive != nullptr) {
- auto effect_info = GetPrimEffectInfo(primitive);
- if (effect_info.memory || effect_info.io) {
- MS_LOG(DEBUG) << "Side Effect Primitive CNode: " << node->DebugString();
- (void)side_effect_nodes->emplace_back(node);
- }
- }
- }
- void BaseFuncGraphEvaluator::CheckSideEffectNodes(const AbstractBasePtr &abstract,
- const std::vector<AnfNodePtr> &side_effect_nodes) const {
- if (!side_effect_nodes.empty()) {
- ValuePtr val = abstract->BuildValue();
- if (!val->isa<AnyValue>()) {
- std::stringstream ss;
- ss << "Side Effect Invalid: Found unsupported syntax in graph mode, those side effect codes would be ignored:\n";
- ss << "-----\n";
- size_t num = 1;
- for (auto &side_effect_node : side_effect_nodes) {
- ss << "# No. " << num << ":\n" << trace::GetDebugInfo(side_effect_node->debug_info()) << "\n";
- ++num;
- }
- ss << "-----\n";
- // All nodes in side_effect_nodes are CNode, so its func_graph() must not be null.
- auto fg = side_effect_nodes[0]->func_graph();
- MS_EXCEPTION_IF_NULL(fg);
- ss << "\nIf a function return a const value or inferred const value, the side effect node would be ignored.\n"
- << "So the codes may not run as the user's expectation, please fix it.\n\n"
- << "In this case, the const value '" << val->ToString() << "' returns: \n"
- << trace::GetDebugInfo(fg->debug_info()) << "\nFor more information about this issue, please refer to "
- << "https://www.mindspore.cn/search?inputValue=Side%20Effect%20Invalid\n";
- MS_LOG(EXCEPTION) << ss.str();
- }
- }
- }
void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame) {
MS_EXCEPTION_IF_NULL(current_stack_frame);
@ -243,12 +204,8 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
abstract = node_eval_result->abstract();
MS_EXCEPTION_IF_NULL(abstract);
MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << abstract->ToString();
- // Check if contains side effect operations.
- CollectSideEffectNodes(node, &side_effect_nodes);
}
MS_EXCEPTION_IF_NULL(abstract);
- CheckSideEffectNodes(abstract, side_effect_nodes);
return abstract;
}
@ -337,6 +294,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
}
MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << abstract->ToString();
SyncFuncGraphIsolatedSideEffectFlag(fg);
trace::TraceGraphEvalLeave(context);
// Decrease the func graph call depth.
DecreaseFunctionCallDepth();

View File

@ -221,8 +221,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
return always_eval_flags_.back();
}
- void CollectSideEffectNodes(const AnfNodePtr &node, std::vector<AnfNodePtr> *side_effect_nodes);
- void CheckSideEffectNodes(const AbstractBasePtr &abstract, const std::vector<AnfNodePtr> &side_effect_nodes) const;
+ virtual void SyncFuncGraphIsolatedSideEffectFlag(const FuncGraphPtr &func_graph) = 0;
protected:
AnalysisContextPtr parent_context_;
@ -256,6 +255,12 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list) override;
std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); }
void SyncFuncGraphIsolatedSideEffectFlag(const FuncGraphPtr &func_graph) override {
if (func_graph->has_isolated_side_effect_node()) {
func_graph_->set_has_isolated_side_effect_node(true);
}
}
private:
FuncGraphPtr func_graph_;
FuncGraphCacheMap func_graph_cache_;
@ -279,6 +284,12 @@ class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator {
}
std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); }
void SyncFuncGraphIsolatedSideEffectFlag(const FuncGraphPtr &func_graph) override {
if (func_graph->has_isolated_side_effect_node()) {
meta_func_graph_->set_has_isolated_side_effect_node(true);
}
}
private:
MetaFuncGraphPtr meta_func_graph_;
FuncGraphCacheMap func_graph_cache_;

View File

@ -1366,7 +1366,8 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
EvalResultPtr GetEvaluatedValueForNameSpace(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) {
// args_spec_list: same as StaticGetter
- if (args_spec_list.size() < 2) {
+ constexpr size_t args_min_size = 2;
+ if (args_spec_list.size() < args_min_size) {
MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
}
MS_EXCEPTION_IF_NULL(out_conf);

View File

@ -783,9 +783,21 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_LOG(EXCEPTION) << "Fail to get input's abstract value, with input config: " << input_conf->ToString()
<< ", in old node: " << c_old->DebugString();
}
- // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
- // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
- AnfNodePtr replace_node = BuildPossibleValueNode(node_input, abs, attrs, node);
+ bool ignore_build_value = false;
+ AnfNodePtr replace_node = nullptr;
+ if (specializer_->engine()->check_isolated_side_effect()) {
+ auto cnode_input = dyn_cast<CNode>(node_input);
+ ignore_build_value = (cnode_input != nullptr && cnode_input->has_isolated_side_effect_node());
+ if (ignore_build_value) {
+ MS_LOG(INFO) << "Don't build value node for CNode which contains isolated side-effect inputs, node: "
+ << cnode_input->DebugString() << ", flag: " << cnode_input->has_isolated_side_effect_node();
+ }
+ }
+ if (!ignore_build_value) {
+ // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
+ // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
+ replace_node = BuildPossibleValueNode(node_input, abs, attrs, node);
+ }
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(input_conf);
replace_node->set_abstract(abs);

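The `ignore_build_value` guard exists because the analyzer may infer a constant for a call whose body still performs side effects; without the guard, `BuildPossibleValueNode` would fold the call CNode into a `ValueNode` and the isolated `Assign` nodes hanging off it would vanish from the graph. A hedged illustration in user code (`update` is a hypothetical helper, not from this commit):

```python
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

def update(param, value):      # hypothetical helper
    P.Assign()(param, value)   # isolated side effect inside the callee
    return True                # inferred as the constant True

class Net(nn.Cell):
    def construct(self, p):
        ok = update(p, Tensor([0]))
        # The abstract value of `ok` is the constant True. Without the flag,
        # the specializer could fold the call to ValueNode(True) and drop the
        # Assign; with the flag set on the call CNode, the fold is skipped.
        return ok
```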
View File

@ -124,6 +124,7 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
engine->SaveEvalResultInCache(conf, result);
}
fg_evaluator->PushAlwaysEvalFlag(always_eval_flag);
fg_evaluator->SyncFuncGraphIsolatedSideEffectFlag(fg);
// Create a new stack frame and set arguments for it.
auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, parent_context);
new_stack_frame->set_args_abs_list(std::move(args_abs_list));
@ -175,9 +176,6 @@ EvalResultPtr StackFrame::Step(const AnalysisEnginePtr &engine) {
}
MS_LOG(DEBUG) << GetInferThread() << "Eval(" << node_conf->ToString() << ") = "
<< (node_eval_result->abstract() ? node_eval_result->abstract()->ToString() : "Abstract null");
- // Check if contains side effect operations.
- fg_evaluator->CollectSideEffectNodes(current_node, &side_effect_nodes_);
return node_eval_result;
}
@ -192,6 +190,17 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
if (last_stack_frame->func_graph()->stub()) {
result = std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
}
// Check if child func graph contains isolated side-effect.
if (engine->check_isolated_side_effect()) {
if (last_stack_frame->func_graph()->has_isolated_side_effect_node()) {
auto cnode = dyn_cast<CNode>(CurrentNode());
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_has_isolated_side_effect_node(true);
cnode->func_graph()->set_has_isolated_side_effect_node(true);
}
}
// Save func graph eval result for specialize.
auto evaluator = last_stack_frame->evaluator();
MS_EXCEPTION_IF_NULL(evaluator);
@ -210,10 +219,6 @@ void StackFrame::Back(const AnalysisEnginePtr &engine, const StackFramePtr &last
<< ", Save result, NodeConfig: " << node_conf->ToString() << ", result: " << result->abstract().get()
<< "/" << result->abstract()->ToString();
engine->SaveEvalResultInCache(node_conf, result);
- fg_evaluator->CheckSideEffectNodes(result->abstract(), last_stack_frame->side_effect_nodes());
- last_stack_frame->side_effect_nodes().clear();
// Leave the call CNode.
trace::TraceEvalCNodeLeave();
}

View File

@ -94,8 +94,6 @@ class StackFrame final : public Base {
const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; }
void set_args_abs_list(const AbstractBasePtrList &&args_abs_list) { args_abs_list_ = args_abs_list; }
- std::vector<AnfNodePtr> &side_effect_nodes() { return side_effect_nodes_; }
std::string ToString() const override {
MS_EXCEPTION_IF_NULL(func_graph_);
std::ostringstream buffer;
@ -138,7 +136,6 @@ class StackFrame final : public Base {
std::vector<AnfNodePtr> node_slots_;
size_t slot_index_;
bool done_;
- std::vector<AnfNodePtr> side_effect_nodes_;
};
} // namespace abstract
} // namespace mindspore

View File

@ -323,12 +323,19 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
MS_EXCEPTION(ValueError) << "This may be not defined, or it can't be a operator. Please check code.";
}
bool contains_isolated_side_effect = false;
ConfigPtrList args_conf_list;
// Ignore the first node which is function name
auto &inputs = cnode->inputs();
for (std::size_t i = 1; i < inputs.size(); i++) {
const AnfNodePtr &node = inputs[i];
args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));
if (check_isolated_side_effect()) {
auto input_cnode = dyn_cast<CNode>(node);
if (input_cnode != nullptr) {
contains_isolated_side_effect |= input_cnode->has_isolated_side_effect_node();
}
}
}
std::vector<EvaluatorPtr> evaluators;
@ -348,6 +355,21 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
func->Visit(build_evaluator);
auto eval_result = ExecuteEvaluators(evaluators, conf, args_conf_list);
// Check if func graph contains isolated side-effect, and sync.
if (check_isolated_side_effect()) {
FuncGraphAbstractClosurePtr func_graph_abs = dyn_cast<FuncGraphAbstractClosure>(func);
if (func_graph_abs != nullptr) {
contains_isolated_side_effect |= func_graph_abs->func_graph()->has_isolated_side_effect_node();
}
MetaFuncGraphAbstractClosurePtr meta_func_graph_abs = dyn_cast<MetaFuncGraphAbstractClosure>(func);
if (meta_func_graph_abs != nullptr) {
contains_isolated_side_effect |= meta_func_graph_abs->meta_func_graph()->has_isolated_side_effect_node();
}
if (contains_isolated_side_effect) {
cnode->set_has_isolated_side_effect_node(true);
conf->func_graph()->set_has_isolated_side_effect_node(true);
}
}
return eval_result;
}

View File

@ -291,7 +291,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
: prim_constructors_(prim_evaluator_map),
func_graph_manager_(func_graph_manager),
forward_count_(0),
- enable_recursive_eval_(common::GetEnv("MS_DEV_RECURSIVE_EVAL") == "1") {}
+ enable_recursive_eval_(common::GetEnv("MS_DEV_RECURSIVE_EVAL") == "1"),
+ check_isolated_side_effect_(false) {}
virtual ~AnalysisEngine() = default;
// func_graph: The func_graph to analyze.
@ -345,6 +346,11 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
bool enable_recursive_eval() const { return enable_recursive_eval_; }
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
bool check_isolated_side_effect() const { return check_isolated_side_effect_; }
void set_check_isolated_side_effect(bool check_isolated_side_effect) {
check_isolated_side_effect_ = check_isolated_side_effect;
}
private:
void SetUndeterminedFlag(const FuncGraphPtr &possible_parent_fg) const;
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
@ -382,6 +388,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
bool enable_recursive_eval_;
bool check_isolated_side_effect_;
#ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_;
#endif

View File

@ -657,12 +657,12 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \brief Add a node debug info.
///
- /// \param node A node debug info of an anf node.
+ /// \param debug_info A node debug info of an anf node.
void AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info);
/// \brief Add a list of node debug infos.
///
- /// \param node A node debug info of an anf node.
+ /// \param debug_infos A node debug info of an anf node.
void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);
/// \brief Check whether this node is in ms_function or not in PyNative Mode.
@ -672,9 +672,21 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
/// \brief Set is_parallel_ for CNode.
///
- /// \param[in] is_parallel_ Boolean.
+ /// \param[in] parallel Boolean.
void set_parallel(bool parallel) { flags_[kIsParallel] = parallel; }
/// \brief Check whether contains a input or indirect input, which is Depend CNode with isolated side-effect node.
///
/// \return True if contains, otherwise false.
bool has_isolated_side_effect_node() const { return has_isolated_side_effect_node_; }
/// \brief Set whether contains a input or indirect input, which is Depend CNode with isolated side-effect node.
///
/// \param[in] has_isolated_side_effect_node Boolean.
void set_has_isolated_side_effect_node(bool has_isolated_side_effect_node) {
has_isolated_side_effect_node_ = has_isolated_side_effect_node;
}
private:
static constexpr size_t kStopGradient = 0;
static constexpr size_t kInForwardFlag = 1;
@ -693,6 +705,9 @@ class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
mindspore::HashMap<std::string, ValuePtr> primal_attrs_;
NodeDebugInfoSet primal_debug_infos_;
NodeDebugInfoSet fused_debug_infos_;
// If the inputs or their inputs contain Depend CNode with isolated side-effect node.
bool has_isolated_side_effect_node_{false};
};
// ANode represents the atomic node. It's derived Parameter and ValueNode.

View File

@ -65,9 +65,18 @@ class FuncGraphBase : public Value {
// Clear the member of FuncGraph to break loop
virtual void DoBreakLoop() = 0;
bool has_isolated_side_effect_node() const { return has_isolated_side_effect_node_; }
void set_has_isolated_side_effect_node(bool has_isolated_side_effect_node) {
has_isolated_side_effect_node_ = has_isolated_side_effect_node;
}
protected:
friend FuncGraphLoopBreaker;
- bool reg_flg = false;
+ bool reg_flg{false};
private:
// If the nodes or their callee's nodes contain Depend CNode with isolated side-effect node.
bool has_isolated_side_effect_node_{false};
};
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CORE_IR_FUNC_GRAPH_BASE_H_

View File

@ -55,7 +55,6 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
compression, use_first_moment, weight_decay_flag, learning_rate,
grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
"""Apply ada factor optimizer to the weight parameter using Tensor."""
- success = True
grad_dtype = F.dtype(grad)
grad_shape = F.shape(grad)
@ -114,17 +113,17 @@ def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, weight_decay, s
p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
p_data_fp32 = P.Sub()(p_data_fp32, update)
- return F.depend(success, P.Assign()(param, F.cast(p_data_fp32, F.dtype(param))))
+ P.Assign()(param, F.cast(p_data_fp32, F.dtype(param)))
+ return True
@_adafactor_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor")
def _run_fused_ada_factor(fused_ada_factor, eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
- success = True
- ret = fused_ada_factor(eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
- grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq)
- return F.depend(success, ret)
+ fused_ada_factor(eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
+ grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq)
+ return True
def trans_to_tensor(param, is_tuple=False, fp32=True):

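The Python-side changes from here on are one repeated rewrite: `return F.depend(success, side_effect_op(...))` becomes a bare `side_effect_op(...)` call followed by `return True` (or the plain loss). With the isolated side-effect flag propagated through analysis, auto-monad still orders and keeps the unused side-effect call, so the `F.depend` plumbing is no longer needed. A minimal before/after sketch (hypothetical `update_param_*` names):

```python
from mindspore.ops import functional as F

# Before: thread the side effect through the return value so pruning keeps it.
def update_param_before(param, value):
    success = True
    return F.depend(success, F.assign(param, value))

# After: call the op for its effect alone and return the constant directly.
def update_param_after(param, value):
    F.assign(param, value)
    return True
```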
View File

@ -293,14 +293,12 @@ def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, grad
def _run_fused_adam_weight_decay_opt(opt, beta1, beta2, eps, lr, weight_decay, param, moment1, moment2, gradient,
decay_flags, optim_filter):
"""Apply FusedAdamWeightDecay optimizer to the weight parameter using Tensor."""
- success = True
if optim_filter:
if decay_flags:
- out = opt(param, moment1, moment2, lr, beta1, beta2, eps, weight_decay, P.Cast()(gradient, F.dtype(param)))
+ opt(param, moment1, moment2, lr, beta1, beta2, eps, weight_decay, P.Cast()(gradient, F.dtype(param)))
else:
- out = opt(param, moment1, moment2, lr, beta1, beta2, eps, 0.0, P.Cast()(gradient, F.dtype(param)))
- return F.depend(success, out)
- return success
+ opt(param, moment1, moment2, lr, beta1, beta2, eps, 0.0, P.Cast()(gradient, F.dtype(param)))
+ return True
def _check_param_value(beta1, beta2, eps, prim_name):

View File

@ -786,9 +786,8 @@ class ApplyProximalAdagradConstantNet(nn.Cell):
self.const = Tensor(9999, mstype.float32)
def construct(self, lr, l1, l2, grad, indices):
- optimizer = self.sparse_apply_proximal_adagrad(
- self.var, self.accum, lr, l1, l2, grad, indices)
- return self.depend(self.const, optimizer)
+ self.sparse_apply_proximal_adagrad(self.var, self.accum, lr, l1, l2, grad, indices)
+ return self.const
@pytest.mark.level1

View File

@ -18,7 +18,6 @@ import time
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
- from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore import ParameterTuple
from mindspore.train.callback import Callback
@ -147,4 +146,5 @@ class TrainOneStepCell(nn.Cell):
grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
- return F.depend(loss, self.optimizer(grads))
+ self.optimizer(grads)
+ return loss

View File

@ -18,7 +18,6 @@ import os
import numpy as np
from sklearn.metrics import roc_auc_score
import mindspore.common.dtype as mstype
- from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
@ -321,7 +320,8 @@ class TrainStepWrap(nn.Cell):
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
- return F.depend(loss, self.optimizer(grads))
+ self.optimizer(grads)
+ return loss
class PredictWithSigmoid(nn.Cell):

View File

@ -17,7 +17,6 @@ import numpy as np
from mindspore import nn
from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
- from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn.optim import Adam, FTRL
@ -312,8 +311,9 @@ class TrainStepWrap(nn.Cell):
if self.reducer_flag:
grads_w = self.grad_reducer_w(grads_w)
grads_d = self.grad_reducer_d(grads_d)
- return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
- self.optimizer_d(grads_d))
+ self.optimizer_w(grads_w)
+ self.optimizer_d(grads_d)
+ return loss_w, loss_d
class PredictWithSigmoid(nn.Cell):

View File

@ -671,7 +671,8 @@ class TrainingWrapper(nn.Cell):
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
- return F.depend(loss, self.optimizer(grads))
+ self.optimizer(grads)
+ return loss
class YoloBoxScores(nn.Cell):

View File

@ -488,8 +488,8 @@ update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
- succ = True
- return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
+ F.assign(accu_grad, cast(grad, mstype.float32))
+ return True
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
@ -497,8 +497,8 @@ accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
@accumulate_accu_grads.register("Tensor", "Tensor")
def _accumulate_accu_grads(accu_grad, grad):
- succ = True
- return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
+ F.assign_add(accu_grad, cast(grad, mstype.float32))
+ return True
zeroslike = P.ZerosLike()
@ -507,8 +507,8 @@ reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
- succ = True
- return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
+ F.assign(accu_grad, zeroslike(accu_grad))
+ return True
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):

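These accumulation helpers are applied leaf-wise over parameter/gradient tuples through `HyperMap`, as in the `TrainForwardBackward` cell in the next file; each leaf now performs its `F.assign`/`F.assign_add` purely for the side effect and returns the constant `True`. A hedged usage sketch (class and attribute names are illustrative):

```python
from mindspore import nn, ParameterTuple
from mindspore.ops import composite as C

class AccumulateGrads(nn.Cell):
    def __init__(self, accu_grads):
        super().__init__()
        self.accu_grads = ParameterTuple(accu_grads)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        # Applies accumulate_accu_grads to each (accu_grad, grad) pair; the
        # result is a tuple of True values, and the assign_adds are kept by
        # auto-monad even though every leaf return value is a constant.
        return self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
```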
View File

@ -89,7 +89,8 @@ class TrainForwardBackward(Cell):
loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
- return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
+ self.hyper_map(F.partial(_sum_op), self.grad_sum, grads)
+ return loss
class TrainOptim(Cell):

View File

@ -34,9 +34,8 @@ def _adam_opt(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient):
"""
Update parameters by AdamWeightDecay op.
"""
- success = True
- next_param = opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
- return F.depend(success, next_param)
+ opt(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
+ return True
def _check_param_value(beta1, beta2, eps, prim_name):

View File

@ -24,7 +24,7 @@ class NetValueNodeWithDepend(nn.Cell):
return output
- @pytest.mark.level2
+ @pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_value_node_with_depend():

View File

@ -22,7 +22,6 @@ from mindspore import context, Model
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
- from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel._utils import _reset_op_id as reset_op_id
@ -99,8 +98,10 @@ class TrainStepWarp(nn.Cell):
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
+ self.optimizer_w(grads_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
- return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
+ self.optimizer_d(grads_d)
+ return loss_w, loss_d
def test_double_subgraphs():

View File

@ -21,7 +21,6 @@ from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
- from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel._utils import _reset_op_id as reset_op_id
@ -117,8 +116,10 @@ class TrainStepWarp(nn.Cell):
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
+ self.optimizer_w(grads_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
- return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
+ self.optimizer_d(grads_d)
+ return loss_w, loss_d
def test_double_subgraphs():

View File

@ -20,7 +20,6 @@ from mindspore import nn
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
- from mindspore.ops import functional as F
from mindspore.ops import operations as P
@ -98,8 +97,10 @@ class TrainStepWrap(nn.Cell):
sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
+ self.optimizer_w(grads_w)
grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
- return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
+ self.optimizer_d(grads_d)
+ return loss_w, loss_d
def test_two_subgraphs():

View File

@ -371,11 +371,12 @@ def test_partial_parameter():
print(net())
- def test_return_const_value_with_side_effect_op():
+ # Better to run in ST and check stdout.
+ def test_return_const_value_with_side_effect_io():
"""
Feature: Side effect
Description: Test side effect with returned const value.
- Expectation: Throw exception.
+ Expectation: Not throw exception.
"""
class Demo(nn.Cell):
def construct(self, x):
@ -386,7 +387,26 @@ def test_return_const_value_with_side_effect_op():
x = [[1, 2, 3, 4], [5, 6, 7, 8]]
net = Demo()
- with pytest.raises(RuntimeError) as info:
- output = net(x)
- print(output)
- assert "Side Effect Invalid" in str(info.value)
+ output = net(x)
+ print(output)
def test_return_const_value_with_side_effect_mem():
"""
Feature: Side effect
Description: Test side effect with returned const value.
Expectation: Not throw exception.
"""
y = Parameter(Tensor([1]))
class Demo(nn.Cell):
def construct(self, x):
P.Assign()(x, Tensor([0]))
P.Assign()(y, Tensor([0]))
return True
x = Parameter(Tensor([1]))
net = Demo()
output = net(x)
print(output)
print(Tensor(x))
print(Tensor(y))
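
For reference, the expectation these tests encode: no "Side Effect Invalid" RuntimeError is raised anymore, the net returns its constant, and the assignments still take effect, so both parameters read back as zero. Approximately (exact tensor formatting varies by version):

```python
# Expected stdout of test_return_const_value_with_side_effect_mem:
# True
# [0]
# [0]
```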