!31745 Enable if parallel call flag by default
Merge pull request !31745 from xychow/enable-if-parallel-call-by-default
This commit is contained in:
commit
43fd864c10
|
@ -124,11 +124,33 @@ static inline std::pair<mindspore::HashSet<size_t>, mindspore::HashMap<size_t, s
|
|||
}
|
||||
// Erase unused parameters.
|
||||
std::vector<AnfNodePtr> new_parameters;
|
||||
const auto &var_arg_node = fg->GetVariableArgParameter();
|
||||
const auto &kw_arg_node = fg->GetVariableKwargParameter();
|
||||
const auto &kw_only_args = fg->GetKwOnlyArgsParameters();
|
||||
for (size_t i = 0; i < parameters.size(); i++) {
|
||||
const auto ¶m_i = parameters[i];
|
||||
if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
|
||||
(void)new_parameters.emplace_back(parameters[i]);
|
||||
(void)new_parameters.emplace_back(param_i);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
|
||||
// VarArgs, KwArgs, KwOnlyArgs may not following the index as the Positional Arguments.
|
||||
if (param_i == var_arg_node) {
|
||||
fg->set_has_vararg(false);
|
||||
(void)unused_parameter_indexes.erase(i);
|
||||
} else if (param_i == kw_arg_node) {
|
||||
fg->set_has_kwarg(false);
|
||||
(void)unused_parameter_indexes.erase(i);
|
||||
} else {
|
||||
bool is_kw_only_arg = std::any_of(kw_only_args.cbegin(), kw_only_args.cend(),
|
||||
[param_i](const auto &kw_only_arg) { return kw_only_arg == param_i; });
|
||||
if (is_kw_only_arg) {
|
||||
if (fg->kwonlyargs_count() <= 0) {
|
||||
MS_LOG(EXCEPTION) << "The kw_only_args_count is 0 when a kw_only_arg should be removed";
|
||||
}
|
||||
fg->set_kwonlyargs_count(fg->kwonlyargs_count() - 1);
|
||||
(void)unused_parameter_indexes.erase(i);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Erase parameter:" << param_i->DebugString() << ", index:" << i;
|
||||
}
|
||||
}
|
||||
manager->SetParameters(fg, new_parameters);
|
||||
|
@ -136,7 +158,7 @@ static inline std::pair<mindspore::HashSet<size_t>, mindspore::HashMap<size_t, s
|
|||
}
|
||||
|
||||
// Adjust the call arguments of func graph whose parameter's eliminated.
|
||||
static inline void AdjustCallerArgs(const CNodePtr &caller,
|
||||
static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &caller,
|
||||
const mindspore::HashSet<size_t> &unused_parameter_indexes) {
|
||||
const FuncGraphManagerPtr &manager = caller->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -148,6 +170,18 @@ static inline void AdjustCallerArgs(const CNodePtr &caller,
|
|||
MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i;
|
||||
}
|
||||
}
|
||||
// Remove any Args which may be packed into VarArgs if VarArgs is not used in called FuncGraph;
|
||||
// Note: 1. If there is any *args or key=value argument in call site, it will be converted to unpack_call
|
||||
// CNode. So in this direct call case, all arguments should be plain arguments.
|
||||
// 2. The arguments in caller may be less than the formal parameters in called as some parameters can have
|
||||
// default value.
|
||||
if (!called->has_vararg() &&
|
||||
caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->hyper_param_count())) {
|
||||
size_t start_offset = called->GetPositionalArgsCount() + 1;
|
||||
size_t end_offset = called->hyper_param_count();
|
||||
new_args.erase(new_args.begin() + start_offset, new_args.end() - end_offset);
|
||||
}
|
||||
|
||||
TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
|
||||
auto new_caller = caller->func_graph()->NewCNode(new_args);
|
||||
new_caller->set_abstract(caller->abstract());
|
||||
|
@ -229,7 +263,7 @@ class ParameterEliminator {
|
|||
AdjustGetItemCall(caller, only_return_parameter_indexes);
|
||||
}
|
||||
// Erase the arguments for eliminated parameters.
|
||||
AdjustCallerArgs(caller, unused_parameter_indexes);
|
||||
AdjustCallerArgs(fg, caller, unused_parameter_indexes);
|
||||
}
|
||||
changes = true;
|
||||
}
|
||||
|
|
|
@ -538,7 +538,7 @@ bool InferenceOptPrepareAction(const ResourcePtr &res) {
|
|||
}
|
||||
|
||||
bool EliminateUnusedParameterAction(const ResourcePtr &res) {
|
||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
|
||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") != "0");
|
||||
static const auto transform_for_half_unroll_call = (common::GetEnv("MS_DEV_FOR_HALF_UNROLL") == "1");
|
||||
if (!transform_tail_call_to_parallel_call && !transform_for_half_unroll_call) {
|
||||
return true;
|
||||
|
|
|
@ -245,22 +245,61 @@ void Parser::LiftIfBranchGraphFV() {
|
|||
}
|
||||
|
||||
namespace {
|
||||
bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
|
||||
auto sort_rhs_first =
|
||||
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
|
||||
return sort_rhs_first;
|
||||
}
|
||||
|
||||
std::pair<CNodePtr, AnfNodePtr> GetRealOutputNodes(const FuncGraphPtr &call_graph) {
|
||||
auto graph_output = call_graph->output();
|
||||
if (graph_output == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "graph_output is null, call_graph: " << call_graph->ToString();
|
||||
}
|
||||
auto graph_output_cnode = dyn_cast<CNode>(graph_output);
|
||||
MS_EXCEPTION_IF_NULL(graph_output_cnode);
|
||||
// If output cnode is not the tail call but a Depend CNode, keep the dependency node for later use.
|
||||
AnfNodePtr graph_dependency_node = nullptr;
|
||||
if (IsDependOfIsolatedNodes(graph_output_cnode)) {
|
||||
auto graph_real_output_cnode = dyn_cast<CNode>(graph_output_cnode->input(1));
|
||||
// Get the dependency node;
|
||||
constexpr auto dependency_node_index = 2;
|
||||
graph_dependency_node = graph_output_cnode->input(dependency_node_index);
|
||||
MS_EXCEPTION_IF_NULL(graph_real_output_cnode);
|
||||
graph_output_cnode = graph_real_output_cnode;
|
||||
}
|
||||
return {graph_output_cnode, graph_dependency_node};
|
||||
}
|
||||
|
||||
void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph,
|
||||
size_t middle_graph_output_cnode_size, bool use_arguments_pack) {
|
||||
// The 'former_graph_output' is middle graph call.
|
||||
auto former_graph_output = former_call_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(former_graph_output);
|
||||
// The 'former_graph_output' is middle graph call or depend.
|
||||
const auto &[former_graph_output_cnode, former_graph_dependency_node] = GetRealOutputNodes(former_call_graph);
|
||||
MS_EXCEPTION_IF_NULL(former_graph_output_cnode);
|
||||
std::vector<AnfNodePtr> inputs({NewValueNode(latter_call_graph)});
|
||||
if (use_arguments_pack) {
|
||||
for (size_t i = 0; i < middle_graph_output_cnode_size - 1; ++i) {
|
||||
auto getitem_input = former_call_graph->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimTupleGetItem), former_graph_output, NewValueNode(SizeToLong(i))});
|
||||
{NewValueNode(prim::kPrimTupleGetItem), former_graph_output_cnode, NewValueNode(SizeToLong(i))});
|
||||
(void)inputs.emplace_back(getitem_input);
|
||||
}
|
||||
} else {
|
||||
(void)inputs.emplace_back(former_graph_output);
|
||||
(void)inputs.emplace_back(former_graph_output_cnode);
|
||||
}
|
||||
auto new_output = former_call_graph->NewCNodeBefore(former_call_graph->return_node(), std::move(inputs));
|
||||
if (former_graph_dependency_node != nullptr) {
|
||||
// Adjust the former funcgraph output with Depend.
|
||||
new_output = former_call_graph->NewCNodeAfter(
|
||||
new_output, {NewValueNode(prim::kPrimDepend), new_output, former_graph_dependency_node});
|
||||
}
|
||||
former_call_graph->set_output(new_output);
|
||||
}
|
||||
|
||||
|
@ -287,86 +326,136 @@ bool TransformParallelCallMiddleToLatter(const FuncGraphPtr &middle_call_graph,
|
|||
return use_arguments_pack;
|
||||
}
|
||||
|
||||
bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
return false;
|
||||
bool IsValueContainScalar(const ValuePtr &value) {
|
||||
if (value->isa<Scalar>()) {
|
||||
return true;
|
||||
}
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
|
||||
auto sort_rhs_first =
|
||||
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
|
||||
return sort_rhs_first;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::pair<CNodePtr, AnfNodePtr> GetRealMiddleOutputNodes(const FuncGraphPtr &middle_call_graph) {
|
||||
auto middle_graph_output = middle_call_graph->output();
|
||||
if (middle_graph_output == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "middle_graph_output is null, middle_call_graph: " << middle_call_graph->ToString();
|
||||
bool IsOutputContainScalar(const CNodePtr &output_cnode) {
|
||||
return std::any_of(output_cnode->inputs().cbegin() + 1, output_cnode->inputs().end(), [](const AnfNodePtr &node) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
return IsValueContainScalar(value_node->value());
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool CheckMiddleGraphOutputContainScalar(
|
||||
const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> ¶llel_call_vec) {
|
||||
std::vector<bool> contains_scalar;
|
||||
for (auto &call_graphs_pair : parallel_call_vec) {
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
|
||||
auto middle_call_graph = call_graphs_pair.second->func_graph();
|
||||
constexpr auto recur_2 = 2;
|
||||
const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
|
||||
const auto middle_graph_output_cnode = middle_graph_output_pair.first;
|
||||
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
|
||||
if (middle_graph_output_cnode_size <= 1) {
|
||||
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
|
||||
return false;
|
||||
}
|
||||
|
||||
static const auto transform_if_const_scalar = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "2");
|
||||
if (!transform_if_const_scalar && IsOutputContainScalar(middle_graph_output_cnode)) {
|
||||
MS_LOG(DEBUG) << "CNode's inputs contain const scalar, " << middle_graph_output_cnode->DebugString(recur_2);
|
||||
contains_scalar.push_back(true);
|
||||
} else {
|
||||
contains_scalar.push_back(false);
|
||||
}
|
||||
}
|
||||
auto middle_graph_output_cnode = dyn_cast<CNode>(middle_graph_output);
|
||||
MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
|
||||
// If latter_call_graph is not the tail call in middle funcgraph, keep the dependency node for later use.
|
||||
AnfNodePtr middle_graph_dependency_node = nullptr;
|
||||
if (IsDependOfIsolatedNodes(middle_graph_output_cnode)) {
|
||||
auto middle_graph_real_output_cnode = dyn_cast<CNode>(middle_graph_output_cnode->input(1));
|
||||
// Get the dependency node;
|
||||
constexpr auto dependency_node_index = 2;
|
||||
middle_graph_dependency_node = middle_graph_output_cnode->input(dependency_node_index);
|
||||
MS_EXCEPTION_IF_NULL(middle_graph_real_output_cnode);
|
||||
middle_graph_output_cnode = middle_graph_real_output_cnode;
|
||||
|
||||
return std::all_of(contains_scalar.cbegin(), contains_scalar.cend(), [](bool is_scalar) { return is_scalar; });
|
||||
}
|
||||
|
||||
bool CheckMiddleGraphOutputPyInterpret(
|
||||
const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> ¶llel_call_vec) {
|
||||
bool contain_py_interpret = false;
|
||||
for (auto &call_graphs_pair : parallel_call_vec) {
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
|
||||
auto middle_call_graph = call_graphs_pair.second->func_graph();
|
||||
constexpr auto recur_2 = 2;
|
||||
const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
|
||||
const auto middle_graph_output_cnode = middle_graph_output_pair.first;
|
||||
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
|
||||
if (middle_graph_output_cnode_size <= 1) {
|
||||
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
|
||||
return false;
|
||||
}
|
||||
|
||||
contain_py_interpret |=
|
||||
std::any_of(middle_graph_output_cnode->inputs().cbegin() + 1, middle_graph_output_cnode->inputs().cend(),
|
||||
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimPyInterpret); });
|
||||
if (contain_py_interpret) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return {middle_graph_output_cnode, middle_graph_dependency_node};
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Transform tail call to parallel call.
|
||||
void Parser::TransformParallelCall() {
|
||||
std::unordered_set<FuncGraphPtr> latter_call_graphs_set;
|
||||
for (auto &call_graphs_pair : parallel_call_graphs_) {
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.first);
|
||||
auto former_call_graph = call_graphs_pair.first->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
|
||||
auto middle_call_graph = call_graphs_pair.second->func_graph();
|
||||
// Transform the call of {middle_graph -> latter_graph}.
|
||||
auto middle_graph_return = middle_call_graph->get_return();
|
||||
if (middle_graph_return == nullptr) {
|
||||
MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
|
||||
for (auto ¶llel_call_vec : parallel_call_graphs_) {
|
||||
bool all_middle_graphs_output_scalar = CheckMiddleGraphOutputContainScalar(parallel_call_vec);
|
||||
if (all_middle_graphs_output_scalar) {
|
||||
MS_LOG(DEBUG) << "All middle func graph's output contain const scalar, cannot transform to Parallel_If.";
|
||||
continue;
|
||||
}
|
||||
constexpr auto recur_3 = 3;
|
||||
MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
|
||||
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
|
||||
const auto &[middle_graph_output_cnode, middle_graph_dependency_node] = GetRealMiddleOutputNodes(middle_call_graph);
|
||||
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
|
||||
if (middle_graph_output_cnode_size <= 1) {
|
||||
// After Join, Value in Abstract of PyInterpret CNode will be kAnyValue, it cannot be PyInterpreted again, so
|
||||
// ignore the transformation.
|
||||
bool is_middle_graphs_output_py_interpret = CheckMiddleGraphOutputPyInterpret(parallel_call_vec);
|
||||
if (is_middle_graphs_output_py_interpret) {
|
||||
MS_LOG(DEBUG) << "Middle func graph's output contain PyInterpret CNode, cannot transform to Parallel_If.";
|
||||
continue;
|
||||
}
|
||||
for (auto &call_graphs_pair : parallel_call_vec) {
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.first);
|
||||
auto former_call_graph = call_graphs_pair.first->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
|
||||
auto middle_call_graph = call_graphs_pair.second->func_graph();
|
||||
// Transform the call of {middle_graph -> latter_graph}.
|
||||
auto middle_graph_return = middle_call_graph->get_return();
|
||||
if (middle_graph_return == nullptr) {
|
||||
MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
|
||||
continue;
|
||||
}
|
||||
constexpr auto recur_3 = 3;
|
||||
constexpr auto recur_2 = 2;
|
||||
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
|
||||
continue;
|
||||
}
|
||||
auto latter_graph_node = middle_graph_output_cnode->input(0);
|
||||
bool use_arguments_pack = TransformParallelCallMiddleToLatter(
|
||||
middle_call_graph, middle_graph_output_cnode, middle_graph_dependency_node, middle_graph_output_cnode_size);
|
||||
MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
|
||||
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
|
||||
const auto &[middle_graph_output_cnode, middle_graph_dependency_node] = GetRealOutputNodes(middle_call_graph);
|
||||
auto middle_graph_output_cnode_size = middle_graph_output_cnode->inputs().size();
|
||||
if (middle_graph_output_cnode_size <= 1) {
|
||||
MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Transform the call of {former_graph -> middle_graph}.
|
||||
auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
|
||||
if (latter_call_graph == nullptr) {
|
||||
constexpr auto recur_2 = 2;
|
||||
MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
|
||||
continue;
|
||||
}
|
||||
if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
|
||||
MS_LOG(DEBUG) << "The latter graph is handled before, " << latter_call_graph->ToString();
|
||||
continue;
|
||||
}
|
||||
(void)latter_call_graphs_set.emplace(latter_call_graph);
|
||||
TransformParallelCallFormerToMiddle(former_call_graph, latter_call_graph, middle_graph_output_cnode_size,
|
||||
use_arguments_pack);
|
||||
auto latter_graph_node = middle_graph_output_cnode->input(0);
|
||||
bool use_arguments_pack = TransformParallelCallMiddleToLatter(
|
||||
middle_call_graph, middle_graph_output_cnode, middle_graph_dependency_node, middle_graph_output_cnode_size);
|
||||
|
||||
MS_LOG(DEBUG) << "Parallel call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
|
||||
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
|
||||
// Transform the call of {former_graph -> middle_graph}.
|
||||
auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
|
||||
if (latter_call_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
|
||||
continue;
|
||||
}
|
||||
if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
|
||||
MS_LOG(DEBUG) << "The latter graph is handled before, " << latter_call_graph->ToString();
|
||||
continue;
|
||||
}
|
||||
(void)latter_call_graphs_set.emplace(latter_call_graph);
|
||||
TransformParallelCallFormerToMiddle(former_call_graph, latter_call_graph, middle_graph_output_cnode_size,
|
||||
use_arguments_pack);
|
||||
|
||||
MS_LOG(DEBUG) << "Parallel call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
|
||||
<< ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
|
||||
}
|
||||
}
|
||||
|
||||
// Lift inner, then lift outer.
|
||||
|
@ -1708,15 +1797,16 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
<< ", false_end: " << false_end->ToString() << ", after_block: " << after_block->ToString();
|
||||
(void)ignored_if_latter_call_graphs_.insert(after_block);
|
||||
}
|
||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") == "1");
|
||||
static const auto transform_tail_call_to_parallel_call = (common::GetEnv("MS_DEV_IF_PARALLEL_CALL") != "0");
|
||||
if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
|
||||
false_branch_graphs.second != nullptr) {
|
||||
true_branch_graphs.first = block;
|
||||
(void)parallel_call_graphs_.emplace_back(true_branch_graphs);
|
||||
MS_LOG(DEBUG) << "Record tail call graphs, true: {former: " << true_branch_graphs.first->func_graph()->ToString()
|
||||
<< ", middle: " << true_branch_graphs.second->func_graph()->ToString() << "}";
|
||||
false_branch_graphs.first = block;
|
||||
(void)parallel_call_graphs_.emplace_back(false_branch_graphs);
|
||||
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> branch_graphs_vec{true_branch_graphs,
|
||||
false_branch_graphs};
|
||||
(void)parallel_call_graphs_.emplace_back(branch_graphs_vec);
|
||||
MS_LOG(DEBUG) << "Record tail call graphs, false: {former: " << false_branch_graphs.first->func_graph()->ToString()
|
||||
<< ", middle: " << false_branch_graphs.second->func_graph()->ToString() << "}";
|
||||
}
|
||||
|
@ -1987,7 +2077,8 @@ FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py:
|
|||
std::pair<FunctionBlockPtr, FunctionBlockPtr> loop_graphs;
|
||||
loop_graphs.first = body_block;
|
||||
loop_graphs.second = after_body_block;
|
||||
(void)parallel_call_graphs_.emplace_back(loop_graphs);
|
||||
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> loop_graphs_vec{loop_graphs};
|
||||
(void)parallel_call_graphs_.emplace_back(loop_graphs_vec);
|
||||
MS_LOG(DEBUG) << "Record tail call graphs, loop: {former: " << loop_graphs.first->func_graph()->ToString()
|
||||
<< ", middle: " << loop_graphs.second->func_graph()->ToString() << "}";
|
||||
// Record the rolled body function, for later lifting operation.
|
||||
|
|
|
@ -339,7 +339,7 @@ class Parser {
|
|||
|
||||
// The func graphs to transform tail call ir to independent call ir.
|
||||
// Contains: {former_graph, middle_graph}, latter_graph is no need.
|
||||
std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> parallel_call_graphs_;
|
||||
std::vector<std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>>> parallel_call_graphs_;
|
||||
// The true branch and false branch call info. of if statement.
|
||||
std::vector<std::tuple<CNodePtr, FunctionBlockPtr, FunctionBlockPtr>> if_branch_calls_;
|
||||
// The rolled_body callers info. for later lifting operation.
|
||||
|
|
|
@ -518,6 +518,8 @@ OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &) {
|
|||
OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &) {
|
||||
opt::OptPassConfig control_group = opt::OptPassConfig(opt::irpass::ConvertSwitchReplacement());
|
||||
OptPassGroupMap map({
|
||||
// After CleanAfterOptA, it may need renormalize to eliminate unused elements in Tuple.
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"control_group", control_group},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
});
|
||||
|
|
|
@ -664,6 +664,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
|
|||
if (possible_parent_fg != nullptr) {
|
||||
possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString();
|
||||
return;
|
||||
}
|
||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval == nullptr) {
|
||||
|
@ -671,14 +672,7 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const Fu
|
|||
}
|
||||
auto fg = fg_eval->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto fg_parent = fg->parent();
|
||||
if (fg_parent != nullptr) {
|
||||
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
|
||||
return;
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "cannot find parent for fg: " << fg->ToString();
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "cannot set Undetermined flag for fg: " << fg->ToString();
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
|
||||
|
|
|
@ -40,7 +40,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
|
|||
has_vararg_(false),
|
||||
has_kwarg_(false),
|
||||
exist_multi_target_(false),
|
||||
kwonlyargs_count_(0),
|
||||
kw_only_args_count_(0),
|
||||
hyper_param_count_(0),
|
||||
is_generated_(false),
|
||||
is_bprop_(false),
|
||||
|
@ -569,21 +569,20 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// one vararg + kwarg so the min param num is 2;
|
||||
constexpr size_t min_param_num = 2;
|
||||
size_t min_param_num = 1;
|
||||
if (has_kwarg_) {
|
||||
if (parameters_.size() < hyper_param_count_ + min_param_num) {
|
||||
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
|
||||
<< hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
|
||||
}
|
||||
return parameters_[(parameters_.size() - hyper_param_count_) - min_param_num];
|
||||
min_param_num += 1;
|
||||
}
|
||||
min_param_num += kw_only_args_count_;
|
||||
min_param_num += hyper_param_count_;
|
||||
|
||||
if (parameters_.size() < hyper_param_count_ + 1) {
|
||||
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
|
||||
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
|
||||
if (parameters_.size() < min_param_num) {
|
||||
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
|
||||
<< " which less than the sum of following: hyper_param_count: " << hyper_param_count_
|
||||
<< ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
|
||||
<< ", kw_only_args_count_: " << kw_only_args_count_;
|
||||
}
|
||||
return parameters_[(parameters_.size() - hyper_param_count_) - 1];
|
||||
return parameters_[parameters_.size() - min_param_num + kw_only_args_count_];
|
||||
}
|
||||
|
||||
std::string FuncGraph::GetVariableArgName() {
|
||||
|
@ -591,24 +590,9 @@ std::string FuncGraph::GetVariableArgName() {
|
|||
return "";
|
||||
}
|
||||
|
||||
// one vararg + kwarg so the min param num is 2;
|
||||
constexpr size_t min_param_num = 2;
|
||||
if (has_kwarg_) {
|
||||
if (parameters_.size() < hyper_param_count_ + min_param_num) {
|
||||
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
|
||||
<< hyper_param_count_ << ", parameters is less than 2 + hyper_param_count";
|
||||
}
|
||||
const auto ¶meter =
|
||||
parameters_[(parameters_.size() - hyper_param_count_) - min_param_num]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
return parameter->name();
|
||||
}
|
||||
|
||||
if (parameters_.size() < hyper_param_count_ + 1) {
|
||||
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
|
||||
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
|
||||
}
|
||||
const auto ¶meter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
|
||||
const auto ¶m_node = GetVariableArgParameter();
|
||||
MS_EXCEPTION_IF_NULL(param_node);
|
||||
const auto ¶meter = param_node->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
return parameter->name();
|
||||
}
|
||||
|
@ -637,6 +621,37 @@ std::string FuncGraph::GetVariableKwargName() {
|
|||
return "";
|
||||
}
|
||||
|
||||
AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
|
||||
AnfNodePtrList kw_only_args;
|
||||
if (kw_only_args_count_ == 0) {
|
||||
return kw_only_args;
|
||||
}
|
||||
|
||||
size_t min_param_num = 0;
|
||||
size_t varargs_kwargs_num = 0;
|
||||
if (has_vararg_) {
|
||||
min_param_num += 1;
|
||||
varargs_kwargs_num += 1;
|
||||
}
|
||||
if (has_kwarg_) {
|
||||
min_param_num += 1;
|
||||
varargs_kwargs_num += 1;
|
||||
}
|
||||
min_param_num += kw_only_args_count_;
|
||||
min_param_num += hyper_param_count_;
|
||||
|
||||
if (parameters_.size() < min_param_num) {
|
||||
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
|
||||
<< " which less than the sum of following: hyper_param_count: " << hyper_param_count_
|
||||
<< ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
|
||||
<< ", kw_only_args_count: " << kw_only_args_count_;
|
||||
}
|
||||
size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
|
||||
std::copy(parameters_.cbegin() + kw_only_args_start_offset,
|
||||
parameters_.cend() - hyper_param_count_ - varargs_kwargs_num, std::back_inserter(kw_only_args));
|
||||
return kw_only_args;
|
||||
}
|
||||
|
||||
int FuncGraph::GetPositionalArgsCount() const {
|
||||
int count = SizeToInt(parameters_.size());
|
||||
if (has_kwarg_) {
|
||||
|
@ -645,7 +660,7 @@ int FuncGraph::GetPositionalArgsCount() const {
|
|||
if (has_vararg_) {
|
||||
count--;
|
||||
}
|
||||
return (count - kwonlyargs_count_) - SizeToInt(hyper_param_count_);
|
||||
return (count - kw_only_args_count_) - SizeToInt(hyper_param_count_);
|
||||
}
|
||||
|
||||
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
|
||||
|
|
|
@ -165,14 +165,16 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
std::map<std::string, AnfNodePtr> ¶meter_default_value() { return parameter_default_value_; }
|
||||
void set_has_vararg(bool has_) { has_vararg_ = has_; }
|
||||
bool has_vararg() const { return has_vararg_; }
|
||||
// Parameters are ordered as: Positional Parameters, Kwonlyargs, *Varargs, **Kwargs, HyperParam;
|
||||
AnfNodePtr GetVariableArgParameter();
|
||||
std::string GetVariableArgName();
|
||||
void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
|
||||
bool has_kwarg() const { return has_kwarg_; }
|
||||
void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
|
||||
int kwonlyargs_count() const { return kwonlyargs_count_; }
|
||||
void set_kwonlyargs_count(int count) { kw_only_args_count_ = count; }
|
||||
int kwonlyargs_count() const { return kw_only_args_count_; }
|
||||
AnfNodePtr GetVariableKwargParameter();
|
||||
std::string GetVariableKwargName();
|
||||
AnfNodePtrList GetKwOnlyArgsParameters();
|
||||
void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
|
||||
size_t hyper_param_count() const { return hyper_param_count_; }
|
||||
int GetPositionalArgsCount() const;
|
||||
|
@ -412,11 +414,11 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
std::vector<AnfNodePtr> parameters_;
|
||||
std::vector<AnfNodePtr> paramter_obj_nodes_;
|
||||
|
||||
// Whether there is a *args and **kwargs, and count kwonlyargs'number.
|
||||
// Whether there is a *args and **kwargs, and count kw_only_args'number.
|
||||
bool has_vararg_;
|
||||
bool has_kwarg_;
|
||||
bool exist_multi_target_;
|
||||
int kwonlyargs_count_;
|
||||
int kw_only_args_count_;
|
||||
// Hyper param is placed on the top graph,
|
||||
// and positioned in the end of the param list, so we record the number to trace the position.
|
||||
size_t hyper_param_count_;
|
||||
|
|
|
@ -543,9 +543,6 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
|
|||
overflow = cond
|
||||
if self.loss_scaling_manager is not None:
|
||||
overflow = self.loss_scaling_manager(self.scale_sense, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, overflow, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
if not overflow:
|
||||
self.optimizer(grads)
|
||||
return (loss, overflow, scaling_sens)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
"""control_ops"""
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Validator as validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
|
||||
|
||||
|
@ -74,7 +74,7 @@ class GeSwitch(PrimitiveWithInfer):
|
|||
raise NotImplementedError
|
||||
|
||||
def infer_shape(self, data, pred):
|
||||
validator.check_equal_int(len(pred), 0, "pred rank", self.name)
|
||||
validator.check_int_range(len(pred), 0, 1, Rel.INC_BOTH, "pred rank", self.name)
|
||||
return data, data
|
||||
|
||||
def infer_dtype(self, data_type, pred_type):
|
||||
|
|
|
@ -106,7 +106,7 @@ def test_if_in_for_tensor_4():
|
|||
@ms_function
|
||||
def control_flow_for():
|
||||
x = Tensor(7)
|
||||
y = Tensor(0)
|
||||
y = Tensor(0.0)
|
||||
for _ in range(3):
|
||||
x = x + y/2
|
||||
if y < Tensor(10) and x < Tensor(20):
|
||||
|
@ -131,7 +131,7 @@ def test_if_in_for_tensor_5():
|
|||
@ms_function
|
||||
def control_flow_for():
|
||||
x = Tensor(7)
|
||||
y = Tensor(0)
|
||||
y = Tensor(0.0)
|
||||
for _ in range(3):
|
||||
x = x + y/2
|
||||
if y < Tensor(10):
|
||||
|
@ -141,7 +141,7 @@ def test_if_in_for_tensor_5():
|
|||
y += Tensor(1)
|
||||
return x + y
|
||||
res = control_flow_for()
|
||||
assert res == 60
|
||||
assert res == 62
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
@ -95,6 +96,7 @@ class Net(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Working on it in Parallel')
|
||||
def test_control_flow():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
||||
|
|
|
@ -29,9 +29,6 @@ grad_all_with_sens = C.GradOperation(sens_param=True)
|
|||
|
||||
def test_parser_three_default_mixed_args_subnet():
|
||||
class SubNetDefaultMixedArgs(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, y, x=3, x1=None, x2=(1, 2)):
|
||||
if x == 3:
|
||||
if x1 == None:
|
||||
|
@ -66,9 +63,6 @@ def test_net_vararg_kwonlyarg_kwarg():
|
|||
return c
|
||||
|
||||
class SecondNet(Cell):
|
||||
def __init__(self):
|
||||
super(SecondNet, self).__init__()
|
||||
|
||||
def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
|
||||
a = x - y
|
||||
b = p * q
|
||||
|
@ -93,9 +87,6 @@ def test_net_vararg_normal_input():
|
|||
return c
|
||||
|
||||
class SecondNet(Cell):
|
||||
def __init__(self):
|
||||
super(SecondNet, self).__init__()
|
||||
|
||||
def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
|
||||
a = x - y
|
||||
b = p * q
|
||||
|
@ -162,9 +153,6 @@ def test_no_vararg():
|
|||
return c
|
||||
|
||||
class SecondNet(Cell):
|
||||
def __init__(self):
|
||||
super(SecondNet, self).__init__()
|
||||
|
||||
def construct(self, x, y, *, z=0, r=1):
|
||||
ret = x + y + z + r
|
||||
return ret
|
||||
|
@ -228,9 +216,6 @@ def test_net_vargs_expand():
|
|||
return self.grad(self.network)(*inputs)
|
||||
|
||||
class AddNet(Cell):
|
||||
def __init__(self):
|
||||
super(AddNet, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return x + y
|
||||
|
||||
|
@ -289,9 +274,6 @@ def test_mixed_precision_const_parameter():
|
|||
|
||||
def test_pass_args_by_key_ward_way():
|
||||
class KeyWardNet(Cell):
|
||||
def __init__(self):
|
||||
super(KeyWardNet, self).__init__()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
return x + y - z
|
||||
|
||||
|
@ -333,3 +315,350 @@ def test_none_input():
|
|||
x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32).reshape((1, 1, 2, 2,)))
|
||||
net = Net()
|
||||
net(x, (4, 4), None, True)
|
||||
|
||||
|
||||
def test_args_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and kwargs properly.
|
||||
Description: Function with unused parameters which are varargs and kwargs.
|
||||
Expectation: compile success and result == 0
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, **kwargs):
|
||||
return 0
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == 0
|
||||
|
||||
|
||||
def test_args_kwonlyargs_1_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs, kwonlyargs and
|
||||
kwargs properly.
|
||||
Description: Function with unused parameters which are varargs, 1 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == 0
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return 0
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == 0
|
||||
|
||||
|
||||
def test_args_kwonlyargs_2_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs, kwonlyargs and
|
||||
kwargs properly.
|
||||
Description: Function with unused parameters which are varargs, 2 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == 0
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return 0
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == 0
|
||||
|
||||
|
||||
def test_args_1_used_kwonlyargs_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs and
|
||||
kwargs properly.
|
||||
Description: Function with unused parameters which are 1 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == x
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return args[0]
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == x
|
||||
|
||||
|
||||
def test_args_2_used_kwonlyargs_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs and
|
||||
kwargs properly.
|
||||
Description: Function with unused parameters which are 1 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == y
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return args[1]
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == y
|
||||
|
||||
|
||||
def test_kwonlyargs_1_used_args_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and
|
||||
kwargs properly.
|
||||
Description: Function with unused parameters which are varargs and kwargs.
|
||||
Expectation: compile success and result == only1
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return only1
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == 3
|
||||
|
||||
|
||||
def test_kwonlyargs_2_used_args_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and
|
||||
kwargs properly.
|
||||
Description: Function with unused parameters which are varargs and kwargs.
|
||||
Expectation: compile success and result == only2
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return only2
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == 4
|
||||
|
||||
|
||||
def test_kwarg_used_args_kwonlyargs_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs and
|
||||
kwonlyargs properly.
|
||||
Description: Function with unused parameters which are varargs and kwonlyargs.
|
||||
Expectation: compile success and result == kw1
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return kwargs["kw1"]
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == 5
|
||||
|
||||
|
||||
def test_args_1_kwonlyargs_1_used_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwargs properly.
|
||||
Description: Function with unused parameters which are kwargs.
|
||||
Expectation: compile success and result == (x, 3)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], only1)
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, 3)
|
||||
|
||||
|
||||
def test_args_2_kwonlyargs_1_used_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwargs properly.
|
||||
Description: Function with unused parameters which are kwargs.
|
||||
Expectation: compile success and result == (x, y, 3)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], only1)
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, y, 3)
|
||||
|
||||
|
||||
def test_args_2_kwonlyargs_2_used_kwarg_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwargs properly.
|
||||
Description: Function with unused parameters which are kwargs.
|
||||
Expectation: compile success and result == (x, y, only1, only2)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], only1, only2)
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, y, 3, 4)
|
||||
|
||||
|
||||
def test_kwonlyargs_1_kwarg_used_args_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs properly.
|
||||
Description: Function with unused parameters which are varargs.
|
||||
Expectation: compile success and result == (y, kw1)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (only1, kwargs["kw1"])
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (3, 5)
|
||||
|
||||
|
||||
def test_kwonlyargs_2_kwarg_used_args_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are varargs properly.
|
||||
Description: Function with unused parameters which are varargs.
|
||||
Expectation: compile success and result == (only1, only2, kw1)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (only1, only2, kwargs["kw1"])
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (3, 4, 5)
|
||||
|
||||
|
||||
def test_args_1_kwarg_used_kwonlyargs_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs properly.
|
||||
Description: Function with unused parameters which are kwonlyargs.
|
||||
Expectation: compile success and result == (x, kw1)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], kwargs["kw1"])
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, 5)
|
||||
|
||||
|
||||
def test_args_2_kwarg_used_kwonlyargs_not_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which are kwonlyargs properly.
|
||||
Description: Function with unused parameters which are kwonlyargs.
|
||||
Expectation: compile success and result == (x, y, kw1)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], kwargs["kw1"])
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, y, 5)
|
||||
|
||||
|
||||
def test_args_1_kwonlyargs_1_kwarg_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass can remove unused parameters and arguments which is kwonlyarg properly.
|
||||
Description: Function with unused parameters which is kwonlyarg.
|
||||
Expectation: compile success and result == (x, only1, kw1)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], only1, kwargs["kw1"])
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, 3, 5)
|
||||
|
||||
|
||||
def test_args_2_kwonlyargs_2_kwarg_used():
|
||||
"""
|
||||
Feature: Eliminate Parameter pass should not remove parameters and arguments all used.
|
||||
Description: Function without unused parameters.
|
||||
Expectation: compile success and result == (x, y, only1, only2, kw1)
|
||||
"""
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], only1, only2, kwargs["kw1"])
|
||||
|
||||
def construct(self, x, y):
|
||||
ret = self.trivial(x, y, kw1=5)
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, y, 3, 4, 5)
|
||||
|
|
Loading…
Reference in New Issue