!31745 Enable if parallel call flag by default

Merge pull request !31745 from xychow/enable-if-parallel-call-by-default
This commit is contained in:
i-robot 2022-05-05 09:28:39 +00:00 committed by Gitee
commit 43fd864c10
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 617 additions and 151 deletions

View File

@ -124,11 +124,33 @@ static inline std::pair<mindspore::HashSet<size_t>, mindspore::HashMap<size_t, s
}
// Erase unused parameters.
std::vector<AnfNodePtr> new_parameters;
const auto &var_arg_node = fg->GetVariableArgParameter();
const auto &kw_arg_node = fg->GetVariableKwargParameter();
const auto &kw_only_args = fg->GetKwOnlyArgsParameters();
for (size_t i = 0; i < parameters.size(); i++) {
const auto &param_i = parameters[i];
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
(void)new_parameters.emplace_back(parameters[i]);
(void)new_parameters.emplace_back(param_i);
} else {
MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
// VarArgs, KwArgs, KwOnlyArgs may not follow the index order of the Positional Arguments.
if (param_i == var_arg_node) {
fg->set_has_vararg(false);
(void)unused_parameter_indexes.erase(i);
} else if (param_i == kw_arg_node) {
fg->set_has_kwarg(false);
(void)unused_parameter_indexes.erase(i);
} else {
bool is_kw_only_arg = std::any_of(kw_only_args.cbegin(), kw_only_args.cend(),
[param_i](const auto &kw_only_arg) { return kw_only_arg == param_i; });
if (is_kw_only_arg) {
if (fg->kwonlyargs_count() <= 0) {
MS_LOG(EXCEPTION) << "The kw_only_args_count is 0 when a kw_only_arg should be removed";
}
fg->set_kwonlyargs_count(fg->kwonlyargs_count() - 1);
(void)unused_parameter_indexes.erase(i);
}
}
MS_LOG(DEBUG) << "Erase parameter:" << param_i->DebugString() << ", index:" << i;
}
}
manager->SetParameters(fg, new_parameters);
@ -136,7 +158,7 @@ static inline std::pair<mindspore::HashSet<size_t>, mindspore::HashMap<size_t, s
}
// Adjust the call arguments of func graph whose parameter's eliminated.
static inline void AdjustCallerArgs(const CNodePtr &caller,
static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &caller,
const mindspore::HashSet<size_t> &unused_parameter_indexes) {
const FuncGraphManagerPtr &manager = caller->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -148,6 +170,18 @@ static inline void AdjustCallerArgs(const CNodePtr &caller,
MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i;
}
}
// Remove any Args which may be packed into VarArgs if VarArgs is not used in called FuncGraph;
// Note: 1. If there is any *args or key=value argument in call site, it will be converted to unpack_call
// CNode. So in this direct call case, all arguments should be plain arguments.
// 2. The arguments in caller may be less than the formal parameters in called as some parameters can have
// default value.
if (!called->has_vararg() &&
caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->hyper_param_count())) {
size_t start_offset = called->GetPositionalArgsCount() + 1;
size_t end_offset = called->hyper_param_count();
new_args.erase(new_args.begin() + start_offset, new_args.end() - end_offset);
}
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
auto new_caller = caller->func_graph()->NewCNode(new_args);
new_caller->set_abstract(caller->abstract());
@ -229,7 +263,7 @@ class ParameterEliminator {
AdjustGetItemCall(caller, only_return_parameter_indexes);
}
// Erase the arguments for eliminated parameters.
AdjustCallerArgs(caller, unused_parameter_indexes);
AdjustCallerArgs(fg, caller, unused_parameter_indexes);
}
changes = true;
}

View File

@ -538,7 +538,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
}
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") != "0");
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
if (!transform_tail_call_to_parallel_call && !transform_for_half_unroll_call) {
return true;

View File

@ -245,22 +245,61 @@ void Parser::LiftIfBranchGraphFV() {
}
namespace {
bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
return false;
}
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return false;
}
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
auto sort_rhs_first =
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
return sort_rhs_first;
}
std::pair<CNodePtr, AnfNodePtr> GetRealOutputNodes(const FuncGraphPtr &call_graph) {
auto graph_output = call_graph->output();
if (graph_output == nullptr) {
MS_LOG(EXCEPTION) << "graph_output is null, call_graph: " << call_graph->ToString();
}
auto graph_output_cnode = dyn_cast<CNode>(graph_output);
MS_EXCEPTION_IF_NULL(graph_output_cnode);
// If output cnode is not the tail call but a Depend CNode, keep the dependency node for later use.
AnfNodePtr graph_dependency_node = nullptr;
if (IsDependOfIsolatedNodes(graph_output_cnode)) {
auto graph_real_output_cnode = dyn_cast<CNode>(graph_output_cnode->input(1));
// Get the dependency node;
constexpr auto dependency_node_index = 2;
graph_dependency_node = graph_output_cnode->input(dependency_node_index);
MS_EXCEPTION_IF_NULL(graph_real_output_cnode);
graph_output_cnode = graph_real_output_cnode;
}
return {graph_output_cnode, graph_dependency_node};
}
void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph,
size_t middle_graph_output_cnode_size, bool use_arguments_pack) {
// The 'former_graph_output' is middle graph call.
auto former_graph_output = former_call_graph->output();
MS_EXCEPTION_IF_NULL(former_graph_output);
// The 'former_graph_output' is middle graph call or depend.
const auto &[former_graph_output_cnode, former_graph_dependency_node] = GetRealOutputNodes(former_call_graph);
MS_EXCEPTION_IF_NULL(former_graph_output_cnode);
std::vector<AnfNodePtr> inputs({NewValueNode(latter_call_graph)});
if (use_arguments_pack) {
for (size_t i = 0; i < middle_graph_output_cnode_size - 1; ++i) {
auto getitem_input = former_call_graph->NewCNodeInOrder(
{NewValueNode(prim::kPrimTupleGetItem), former_graph_output, NewValueNode(SizeToLong(i))});
{NewValueNode(prim::kPrimTupleGetItem), former_graph_output_cnode, NewValueNode(SizeToLong(i))});
(void)inputs.emplace_back(getitem_input);
}
} else {
(void)inputs.emplace_back(former_graph_output);
(void)inputs.emplace_back(former_graph_output_cnode);
}
auto new_output = former_call_graph->NewCNodeBefore(former_call_graph->return_node(), std::move(inputs));
if (former_graph_dependency_node != nullptr) {
// Adjust the former funcgraph output with Depend.
new_output = former_call_graph->NewCNodeAfter(
new_output, {NewValueNode(prim::kPrimDepend), new_output, former_graph_dependency_node});
}
former_call_graph->set_output(new_output);
}
@ -287,86 +326,136 @@ bool TransformParallelCallMiddleToLatter(const FuncGraphPtr &middle_call_graph,
return use_arguments_pack;
}
bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
return false;
bool IsValueContainScalar(const ValuePtr &value) {
if (value->isa<Scalar>()) {
return true;
}
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return false;
}
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
auto sort_rhs_first =
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
return sort_rhs_first;
return false;
}
std::pair<CNodePtr, AnfNodePtr> GetRealMiddleOutputNodes(const FuncGraphPtr &middle_call_graph) {
auto middle_graph_output = middle_call_graph->output();
if (middle_graph_output == nullptr) {
MS_LOG(EXCEPTION) << "middle_graph_output is null, middle_call_graph: " << middle_call_graph->ToString();
// Return true when any real input of `output_cnode` (skipping the primitive at
// index 0) is a ValueNode whose value contains a scalar.
bool IsOutputContainScalar(const CNodePtr &output_cnode) {
  const auto &all_inputs = output_cnode->inputs();
  return std::any_of(all_inputs.cbegin() + 1, all_inputs.cend(), [](const AnfNodePtr &input) {
    if (!input->isa<ValueNode>()) {
      return false;
    }
    return IsValueContainScalar(input->cast<ValueNodePtr>()->value());
  });
}
bool CheckMiddleGraphOutputContainScalar(
const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> &parallel_call_vec) {
std::vector<bool> contains_scalar;
for (auto &call_graphs_pair : parallel_call_vec) {
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
auto middle_call_graph = call_graphs_pair.second->func_graph();
constexpr auto recur_2 = 2;
const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
const auto middle_graph_output_cnode = middle_graph_output_pair.first;
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
if (middle_graph_output_cnode_size <= 1) {
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
return false;
}
static const auto transform_if_const_scalar = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "2");
if (!transform_if_const_scalar && IsOutputContainScalar(middle_graph_output_cnode)) {
MS_LOG(DEBUG) << "CNode's inputs contain const scalar, " << middle_graph_output_cnode->DebugString(recur_2);
contains_scalar.push_back(true);
} else {
contains_scalar.push_back(false);
}
}
auto middle_graph_output_cnode = dyn_cast<CNode>(middle_graph_output);
MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
// If latter_call_graph is not the tail call in middle funcgraph, keep the dependency node for later use.
AnfNodePtr middle_graph_dependency_node = nullptr;
if (IsDependOfIsolatedNodes(middle_graph_output_cnode)) {
auto middle_graph_real_output_cnode = dyn_cast<CNode>(middle_graph_output_cnode->input(1));
// Get the dependency node;
constexpr auto dependency_node_index = 2;
middle_graph_dependency_node = middle_graph_output_cnode->input(dependency_node_index);
MS_EXCEPTION_IF_NULL(middle_graph_real_output_cnode);
middle_graph_output_cnode = middle_graph_real_output_cnode;
return std::all_of(contains_scalar.cbegin(), contains_scalar.cend(), [](bool is_scalar) { return is_scalar; });
}
bool CheckMiddleGraphOutputPyInterpret(
const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> &parallel_call_vec) {
bool contain_py_interpret = false;
for (auto &call_graphs_pair : parallel_call_vec) {
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
auto middle_call_graph = call_graphs_pair.second->func_graph();
constexpr auto recur_2 = 2;
const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
const auto middle_graph_output_cnode = middle_graph_output_pair.first;
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
if (middle_graph_output_cnode_size <= 1) {
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
return false;
}
contain_py_interpret |=
std::any_of(middle_graph_output_cnode->inputs().cbegin() + 1, middle_graph_output_cnode->inputs().cend(),
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimPyInterpret); });
if (contain_py_interpret) {
return true;
}
}
return {middle_graph_output_cnode, middle_graph_dependency_node};
return false;
}
} // namespace
// Transform tail call to parallel call.
void Parser::TransformParallelCall() {
std::unordered_set<FuncGraphPtr> latter_call_graphs_set;
for (auto &call_graphs_pair : parallel_call_graphs_) {
MS_EXCEPTION_IF_NULL(call_graphs_pair.first);
auto former_call_graph = call_graphs_pair.first->func_graph();
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
auto middle_call_graph = call_graphs_pair.second->func_graph();
// Transform the call of {middle_graph -> latter_graph}.
auto middle_graph_return = middle_call_graph->get_return();
if (middle_graph_return == nullptr) {
MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
for (auto &parallel_call_vec : parallel_call_graphs_) {
bool all_middle_graphs_output_scalar = CheckMiddleGraphOutputContainScalar(parallel_call_vec);
if (all_middle_graphs_output_scalar) {
MS_LOG(DEBUG) << "All middle func graph's output contain const scalar, cannot transform to Parallel_If.";
continue;
}
constexpr auto recur_3 = 3;
MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
const auto &[middle_graph_output_cnode, middle_graph_dependency_node] = GetRealMiddleOutputNodes(middle_call_graph);
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
if (middle_graph_output_cnode_size <= 1) {
// After Join, Value in Abstract of PyInterpret CNode will be kAnyValue, it cannot be PyInterpreted again, so
// ignore the transformation.
bool is_middle_graphs_output_py_interpret = CheckMiddleGraphOutputPyInterpret(parallel_call_vec);
if (is_middle_graphs_output_py_interpret) {
MS_LOG(DEBUG) << "Middle func graph's output contain PyInterpret CNode, cannot transform to Parallel_If.";
continue;
}
for (auto &call_graphs_pair : parallel_call_vec) {
MS_EXCEPTION_IF_NULL(call_graphs_pair.first);
auto former_call_graph = call_graphs_pair.first->func_graph();
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
auto middle_call_graph = call_graphs_pair.second->func_graph();
// Transform the call of {middle_graph -> latter_graph}.
auto middle_graph_return = middle_call_graph->get_return();
if (middle_graph_return == nullptr) {
MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
continue;
}
constexpr auto recur_3 = 3;
constexpr auto recur_2 = 2;
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
continue;
}
auto latter_graph_node = middle_graph_output_cnode->input(0);
bool use_arguments_pack = TransformParallelCallMiddleToLatter(
middle_call_graph, middle_graph_output_cnode, middle_graph_dependency_node, middle_graph_output_cnode_size);
MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
const auto &[middle_graph_output_cnode, middle_graph_dependency_node] = GetRealOutputNodes(middle_call_graph);
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
if (middle_graph_output_cnode_size <= 1) {
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
continue;
}
// Transform the call of {former_graph -> middle_graph}.
auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
if (latter_call_graph == nullptr) {
constexpr auto recur_2 = 2;
MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
continue;
}
if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
MS_LOG(DEBUG) << "The latter graph is handled before, " << latter_call_graph->ToString();
continue;
}
(void)latter_call_graphs_set.emplace(latter_call_graph);
TransformParallelCallFormerToMiddle(former_call_graph, latter_call_graph, middle_graph_output_cnode_size,
use_arguments_pack);
auto latter_graph_node = middle_graph_output_cnode->input(0);
bool use_arguments_pack = TransformParallelCallMiddleToLatter(
middle_call_graph, middle_graph_output_cnode, middle_graph_dependency_node, middle_graph_output_cnode_size);
MS_LOG(DEBUG) << "Parallel call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
// Transform the call of {former_graph -> middle_graph}.
auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
if (latter_call_graph == nullptr) {
MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
continue;
}
if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
MS_LOG(DEBUG) << "The latter graph is handled before, " << latter_call_graph->ToString();
continue;
}
(void)latter_call_graphs_set.emplace(latter_call_graph);
TransformParallelCallFormerToMiddle(former_call_graph, latter_call_graph, middle_graph_output_cnode_size,
use_arguments_pack);
MS_LOG(DEBUG) << "Parallel call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
}
}
// Lift inner, then lift outer.
@ -1708,15 +1797,16 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
(void)ignored_if_latter_call_graphs_.insert(after_block);
}
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") != "0");
if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
false_branch_graphs.second != nullptr) {
true_branch_graphs.first = block;
(void)parallel_call_graphs_.emplace_back(true_branch_graphs);
MS_LOG(DEBUG) << "Record tail call graphs, true: {former: " << true_branch_graphs.first->func_graph()->ToString()
<< ", middle: " << true_branch_graphs.second->func_graph()->ToString() << "}";
false_branch_graphs.first = block;
(void)parallel_call_graphs_.emplace_back(false_branch_graphs);
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> branch_graphs_vec{true_branch_graphs,
false_branch_graphs};
(void)parallel_call_graphs_.emplace_back(branch_graphs_vec);
MS_LOG(DEBUG) << "Record tail call graphs, false: {former: " << false_branch_graphs.first->func_graph()->ToString()
<< ", middle: " << false_branch_graphs.second->func_graph()->ToString() << "}";
}
@ -1987,7 +2077,8 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
std::pair<FunctionBlockPtr, FunctionBlockPtr> loop_graphs;
loop_graphs.first = body_block;
loop_graphs.second = after_body_block;
(void)parallel_call_graphs_.emplace_back(loop_graphs);
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> loop_graphs_vec{loop_graphs};
(void)parallel_call_graphs_.emplace_back(loop_graphs_vec);
MS_LOG(DEBUG) << "Record tail call graphs, loop: {former: " << loop_graphs.first->func_graph()->ToString()
<< ", middle: " << loop_graphs.second->func_graph()->ToString() << "}";
// Record the rolled body function, for later lifting operation.

View File

@ -339,7 +339,7 @@ class Parser {
// The func graphs to transform tail call ir to independent call ir.
// Contains: {former_graph, middle_graph}, latter_graph is no need.
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> parallel_call_graphs_;
std::vector<std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>>> parallel_call_graphs_;
// The true branch and false branch call info. of if statement.
std::vector<std::tuple<CNodePtr, FunctionBlockPtr, FunctionBlockPtr>> if_branch_calls_;
// The rolled_body callers info. for later lifting operation.

View File

@ -518,6 +518,8 @@ OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &) {
OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &) {
opt::OptPassConfig control_group = opt::OptPassConfig(opt::irpass::ConvertSwitchReplacement());
OptPassGroupMap map({
// After CleanAfterOptA, it may need renormalize to eliminate unused elements in Tuple.
{"renormalize", opt::OptPassConfig::Renormalize()},
{"control_group", control_group},
{"renormalize", opt::OptPassConfig::Renormalize()},
});

View File

@ -664,6 +664,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
if (possible_parent_fg != nullptr) {
possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString();
return;
}
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
if (fg_eval == nullptr) {
@ -671,14 +672,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
}
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto fg_parent = fg->parent();
if (fg_parent != nullptr) {
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
return;
} else {
MS_LOG(DEBUG) << "cannot find parent for fg: " << fg->ToString();
}
MS_LOG(EXCEPTION) << "cannot set Undetermined flag for fg: " << fg->ToString();
}
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,

View File

@ -40,7 +40,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
has_vararg_(false),
has_kwarg_(false),
exist_multi_target_(false),
kwonlyargs_count_(0),
kw_only_args_count_(0),
hyper_param_count_(0),
is_generated_(false),
is_bprop_(false),
@ -569,21 +569,20 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() {
return nullptr;
}
// one vararg + kwarg so the min param num is 2;
constexpr size_t min_param_num = 2;
size_t min_param_num = 1;
if (has_kwarg_) {
if (parameters_.size() < hyper_param_count_ + min_param_num) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
}
return parameters_[(parameters_.size() - hyper_param_count_) - min_param_num];
min_param_num += 1;
}
min_param_num += kw_only_args_count_;
min_param_num += hyper_param_count_;
if (parameters_.size() < hyper_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
if (parameters_.size() < min_param_num) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
<< " which less than the sum of following: hyper_param_count: " << hyper_param_count_
<< ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
<< ", kw_only_args_count_: " << kw_only_args_count_;
}
return parameters_[(parameters_.size() - hyper_param_count_) - 1];
return parameters_[parameters_.size() - min_param_num + kw_only_args_count_];
}
std::string FuncGraph::GetVariableArgName() {
@ -591,24 +590,9 @@ std::string FuncGraph::GetVariableArgName() {
return "";
}
// one vararg + kwarg so the min param num is 2;
constexpr size_t min_param_num = 2;
if (has_kwarg_) {
if (parameters_.size() < hyper_param_count_ + min_param_num) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
}
const auto &parameter =
parameters_[(parameters_.size() - hyper_param_count_) - min_param_num]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
if (parameters_.size() < hyper_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
}
const auto &parameter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
const auto &param_node = GetVariableArgParameter();
MS_EXCEPTION_IF_NULL(param_node);
const auto &parameter = param_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@ -637,6 +621,37 @@ std::string FuncGraph::GetVariableKwargName() {
return "";
}
// Collect the keyword-only argument parameters of this func graph.
// Parameter layout (per the class header): Positional, KwOnlyArgs, *Varargs,
// **Kwargs, HyperParam — so kw-only args sit just before the trailing
// vararg/kwarg/hyper-param section.
AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
  AnfNodePtrList kw_only_args;
  if (kw_only_args_count_ == 0) {
    return kw_only_args;
  }
  // min_param_num: minimum number of trailing parameters that must exist
  // (kw-only args + optional *args/**kwargs slots + hyper params).
  size_t min_param_num = 0;
  // varargs_kwargs_num: how many trailing slots *args/**kwargs occupy.
  size_t varargs_kwargs_num = 0;
  if (has_vararg_) {
    min_param_num += 1;
    varargs_kwargs_num += 1;
  }
  if (has_kwarg_) {
    min_param_num += 1;
    varargs_kwargs_num += 1;
  }
  min_param_num += kw_only_args_count_;
  min_param_num += hyper_param_count_;
  if (parameters_.size() < min_param_num) {
    MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
                      << " which less than the sum of following: hyper_param_count: " << hyper_param_count_
                      << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
                      << ", kw_only_args_count: " << kw_only_args_count_;
  }
  // Kw-only args span [size - min_param_num, size - hyper_param_count_ - varargs_kwargs_num).
  size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
  std::copy(parameters_.cbegin() + kw_only_args_start_offset,
            parameters_.cend() - hyper_param_count_ - varargs_kwargs_num, std::back_inserter(kw_only_args));
  return kw_only_args;
}
int FuncGraph::GetPositionalArgsCount() const {
int count = SizeToInt(parameters_.size());
if (has_kwarg_) {
@ -645,7 +660,7 @@ int FuncGraph::GetPositionalArgsCount() const {
if (has_vararg_) {
count--;
}
return (count - kwonlyargs_count_) - SizeToInt(hyper_param_count_);
return (count - kw_only_args_count_) - SizeToInt(hyper_param_count_);
}
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {

View File

@ -165,14 +165,16 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
void set_has_vararg(bool has_) { has_vararg_ = has_; }
bool has_vararg() const { return has_vararg_; }
// Parameters are ordered as: Positional Parameters, Kwonlyargs, *Varargs, **Kwargs, HyperParam;
AnfNodePtr GetVariableArgParameter();
std::string GetVariableArgName();
void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
bool has_kwarg() const { return has_kwarg_; }
void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
int kwonlyargs_count() const { return kwonlyargs_count_; }
void set_kwonlyargs_count(int count) { kw_only_args_count_ = count; }
int kwonlyargs_count() const { return kw_only_args_count_; }
AnfNodePtr GetVariableKwargParameter();
std::string GetVariableKwargName();
AnfNodePtrList GetKwOnlyArgsParameters();
void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
size_t hyper_param_count() const { return hyper_param_count_; }
int GetPositionalArgsCount() const;
@ -412,11 +414,11 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
std::vector<AnfNodePtr> parameters_;
std::vector<AnfNodePtr> paramter_obj_nodes_;
// Whether there is a *args and **kwargs, and count kwonlyargs'number.
// Whether there are *args and **kwargs, and the count of kw_only_args.
bool has_vararg_;
bool has_kwarg_;
bool exist_multi_target_;
int kwonlyargs_count_;
int kw_only_args_count_;
// Hyper param is placed on the top graph,
// and positioned in the end of the param list, so we record the number to trace the position.
size_t hyper_param_count_;

View File

@ -543,9 +543,6 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
overflow = cond
if self.loss_scaling_manager is not None:
overflow = self.loss_scaling_manager(self.scale_sense, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, overflow, scaling_sens)
return F.depend(ret, succ)
if not overflow:
self.optimizer(grads)
return (loss, overflow, scaling_sens)

View File

@ -15,7 +15,7 @@
"""control_ops"""
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
@ -74,7 +74,7 @@ class GeSwitch(PrimitiveWithInfer):
raise NotImplementedError
def infer_shape(self, data, pred):
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
validator.check_int_range(len(pred), 0, 1, Rel.INC_BOTH, "pred rank", self.name)
return data, data
def infer_dtype(self, data_type, pred_type):

View File

@ -106,7 +106,7 @@ def test_if_in_for_tensor_4():
@ms_function
def control_flow_for():
x = Tensor(7)
y = Tensor(0)
y = Tensor(0.0)
for _ in range(3):
x = x + y/2
if y < Tensor(10) and x < Tensor(20):
@ -131,7 +131,7 @@ def test_if_in_for_tensor_5():
@ms_function
def control_flow_for():
x = Tensor(7)
y = Tensor(0)
y = Tensor(0.0)
for _ in range(3):
x = x + y/2
if y < Tensor(10):
@ -141,7 +141,7 @@ def test_if_in_for_tensor_5():
y += Tensor(1)
return x + y
res = control_flow_for()
assert res == 60
assert res == 62
@pytest.mark.level0

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
@ -95,6 +96,7 @@ class Net(nn.Cell):
return out
@pytest.mark.skip(reason='Working on it in Parallel')
def test_control_flow():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
context.set_auto_parallel_context(device_num=8, global_rank=0)

View File

@ -29,9 +29,6 @@ grad_all_with_sens = C.GradOperation(sens_param=True)
def test_parser_three_default_mixed_args_subnet():
class SubNetDefaultMixedArgs(Cell):
def __init__(self):
super().__init__()
def construct(self, y, x=3, x1=None, x2=(1, 2)):
if x == 3:
if x1 == None:
@ -66,9 +63,6 @@ def test_net_vararg_kwonlyarg_kwarg():
return c
class SecondNet(Cell):
def __init__(self):
super(SecondNet, self).__init__()
def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
a = x - y
b = p * q
@ -93,9 +87,6 @@ def test_net_vararg_normal_input():
return c
class SecondNet(Cell):
def __init__(self):
super(SecondNet, self).__init__()
def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
a = x - y
b = p * q
@ -162,9 +153,6 @@ def test_no_vararg():
return c
class SecondNet(Cell):
def __init__(self):
super(SecondNet, self).__init__()
def construct(self, x, y, *, z=0, r=1):
ret = x + y + z + r
return ret
@ -228,9 +216,6 @@ def test_net_vargs_expand():
return self.grad(self.network)(*inputs)
class AddNet(Cell):
def __init__(self):
super(AddNet, self).__init__()
def construct(self, x, y):
return x + y
@ -289,9 +274,6 @@ def test_mixed_precision_const_parameter():
def test_pass_args_by_key_ward_way():
class KeyWardNet(Cell):
def __init__(self):
super(KeyWardNet, self).__init__()
def construct(self, x, y, z):
return x + y - z
@ -333,3 +315,350 @@ def test_none_input():
x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32).reshape((1, 1, 2, 2,)))
net = Net()
net(x, (4, 4), None, True)
def test_args_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and kwargs properly.
    Description: Function with unused parameters which are varargs and kwargs.
    Expectation: compile success and result == 0
    """
    class Net(Cell):
        # *args and **kwargs are intentionally unused; the pass should erase them.
        def trivial(self, *args, **kwargs):
            return 0

        def construct(self, x, y):
            ret = self.trivial(x, y)
            return ret

    net = Net()
    x = 1
    y = 2
    assert net(x, y) == 0
def test_args_kwonlyargs_1_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs, kwonlyargs and
    kwargs properly.
    Description: Function with unused parameters which are varargs, 1 kwonlyarg and kwargs.
    Expectation: compile success and result == 0
    """
    class Net(Cell):
        # *args, the kw-only arg `only1`, and **kwargs are all unused.
        def trivial(self, *args, only1=3, **kwargs):
            return 0

        def construct(self, x, y):
            ret = self.trivial(x, y)
            return ret

    net = Net()
    x = 1
    y = 2
    assert net(x, y) == 0
def test_args_kwonlyargs_2_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs, kwonlyargs and
    kwargs properly.
    Description: Function with unused parameters which are varargs, 2 kwonlyargs and kwargs.
    Expectation: compile success and result == 0
    """
    class Net(Cell):
        # *args, both kw-only args, and **kwargs are all unused.
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            return 0

        def construct(self, x, y):
            ret = self.trivial(x, y)
            return ret

    net = Net()
    x = 1
    y = 2
    assert net(x, y) == 0
def test_args_1_used_kwonlyargs_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs and
    kwargs properly.
    Description: Function that uses the first vararg only; its kwonlyarg and kwargs are never referenced.
    Expectation: compile success and result == x
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, **kwargs):
            # Only the first positional vararg is consumed.
            return args[0]

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == first
def test_args_2_used_kwonlyargs_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs and
    kwargs properly.
    Description: Function that uses the second vararg only; its kwonlyarg and kwargs are never referenced.
    Expectation: compile success and result == y
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, **kwargs):
            # Only the second positional vararg is consumed.
            return args[1]

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == second
def test_kwonlyargs_1_used_args_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and
    kwargs properly.
    Description: Function that uses only its kwonlyarg; varargs and kwargs are never referenced.
    Expectation: compile success and result == only1
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, **kwargs):
            # Only the keyword-only default is consumed.
            return only1

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == 3
def test_kwonlyargs_2_used_args_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and
    kwargs properly.
    Description: Function that uses only its second kwonlyarg; varargs and kwargs are never referenced.
    Expectation: compile success and result == only2
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Only the second keyword-only default is consumed.
            return only2

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == 4
def test_kwarg_used_args_kwonlyargs_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and
    kwonlyargs properly.
    Description: Function that uses only an entry of kwargs; varargs and kwonlyargs are never referenced.
    Expectation: compile success and result == kw1
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Only the keyword argument "kw1" is consumed.
            return kwargs["kw1"]

        def construct(self, x, y):
            return self.trivial(x, y, kw1=5)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == 5
def test_args_1_kwonlyargs_1_used_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwargs properly.
    Description: Function that uses the first vararg and first kwonlyarg; kwargs are never referenced.
    Expectation: compile success and result == (x, 3)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # First vararg and first keyword-only default are consumed.
            return (args[0], only1)

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, 3)
def test_args_2_kwonlyargs_1_used_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwargs properly.
    Description: Function that uses both varargs and the first kwonlyarg; kwargs are never referenced.
    Expectation: compile success and result == (x, y, 3)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Both varargs and the first keyword-only default are consumed.
            return (args[0], args[1], only1)

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, second, 3)
def test_args_2_kwonlyargs_2_used_kwarg_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwargs properly.
    Description: Function that uses both varargs and both kwonlyargs; kwargs are never referenced.
    Expectation: compile success and result == (x, y, only1, only2)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Everything except **kwargs is consumed.
            return (args[0], args[1], only1, only2)

        def construct(self, x, y):
            return self.trivial(x, y)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, second, 3, 4)
def test_kwonlyargs_1_kwarg_used_args_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs properly.
    Description: Function with unused parameters which are varargs (the kwonlyarg only2 is also unused).
    Expectation: compile success and result == (only1, kw1)
    """
    # NOTE(review): the Expectation previously read "(y, kw1)", but trivial()
    # returns (only1, kwargs["kw1"]) and the assert checks (3, 5), i.e.
    # (only1, kw1); the docstring is corrected to match the actual behavior.
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Only the first keyword-only default and "kw1" are consumed.
            return (only1, kwargs["kw1"])

        def construct(self, x, y):
            ret = self.trivial(x, y, kw1=5)
            return ret

    net = Net()
    x = 1
    y = 2
    assert net(x, y) == (3, 5)
def test_kwonlyargs_2_kwarg_used_args_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs properly.
    Description: Function that uses both kwonlyargs and an entry of kwargs; varargs are never referenced.
    Expectation: compile success and result == (only1, only2, kw1)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Everything except *args is consumed.
            return (only1, only2, kwargs["kw1"])

        def construct(self, x, y):
            return self.trivial(x, y, kw1=5)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (3, 4, 5)
def test_args_1_kwarg_used_kwonlyargs_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs properly.
    Description: Function that uses the first vararg and an entry of kwargs; kwonlyargs are never referenced.
    Expectation: compile success and result == (x, kw1)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # First vararg and the keyword argument "kw1" are consumed.
            return (args[0], kwargs["kw1"])

        def construct(self, x, y):
            return self.trivial(x, y, kw1=5)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, 5)
def test_args_2_kwarg_used_kwonlyargs_not_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs properly.
    Description: Function that uses both varargs and an entry of kwargs; kwonlyargs are never referenced.
    Expectation: compile success and result == (x, y, kw1)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Both varargs and the keyword argument "kw1" are consumed.
            return (args[0], args[1], kwargs["kw1"])

        def construct(self, x, y):
            return self.trivial(x, y, kw1=5)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, second, 5)
def test_args_1_kwonlyargs_1_kwarg_used():
    """
    Feature: Eliminate Parameter pass can remove unused parameters and arguments which is kwonlyarg properly.
    Description: Function that uses the first vararg, first kwonlyarg and an entry of kwargs; the second
    kwonlyarg is never referenced.
    Expectation: compile success and result == (x, only1, kw1)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Everything except args[1] and only2 is consumed.
            return (args[0], only1, kwargs["kw1"])

        def construct(self, x, y):
            return self.trivial(x, y, kw1=5)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, 3, 5)
def test_args_2_kwonlyargs_2_kwarg_used():
    """
    Feature: Eliminate Parameter pass should not remove parameters and arguments all used.
    Description: Function with no unused parameters at all.
    Expectation: compile success and result == (x, y, only1, only2, kw1)
    """
    class Net(Cell):
        def trivial(self, *args, only1=3, only2=4, **kwargs):
            # Every parameter kind is consumed, so nothing should be eliminated.
            return (args[0], args[1], only1, only2, kwargs["kw1"])

        def construct(self, x, y):
            return self.trivial(x, y, kw1=5)

    model = Net()
    first, second = 1, 2
    assert model(first, second) == (first, second, 3, 4, 5)