From a89b662b62195a462438cc9f57820724a2848fbd Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Mon, 30 May 2022 14:50:27 +0800 Subject: [PATCH] Change the weight parameter to FV parameter in FuncGraph, and add ST cases. --- .../irpass/less_batch_normalization.cc | 8 +- .../optimizer/irpass/parameter_eliminate.h | 4 +- .../cache_embedding/cache_embedding.cc | 20 +- .../pipeline_transformer.cc | 2 +- mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 11 +- mindspore/ccsrc/utils/convert_utils_py.cc | 8 +- mindspore/core/ir/func_graph.cc | 89 +++--- mindspore/core/ir/func_graph.h | 13 +- mindspore/core/ir/func_graph_cloner.cc | 4 +- mindspore/core/ir/func_graph_extends.cc | 10 +- .../core/load_mindir/anf_model_parser.cc | 6 +- .../construct_input/test_outermost_input.py | 290 ++++++++++++++++++ .../test_inner_dyn_shape_ms_function.py | 0 .../test_nested_calling_ms_function.py | 0 .../test_outmost_dyn_shape_ms_function.py | 0 .../test_pynative_lenet_ms_function.py | 0 .../ms_function/test_pynative_ms_function.py | 0 .../test_pynative_outermost_non_tensor.py | 82 ----- .../dynamic_shape/dynamic_shape_pass_test.cc | 2 +- ...st_outermost_net_pass_non_tensor_inputs.py | 130 +------- 20 files changed, 372 insertions(+), 307 deletions(-) create mode 100644 tests/st/construct_input/test_outermost_input.py rename tests/st/{pynative => }/ms_function/test_inner_dyn_shape_ms_function.py (100%) rename tests/st/{pynative => }/ms_function/test_nested_calling_ms_function.py (100%) rename tests/st/{pynative => }/ms_function/test_outmost_dyn_shape_ms_function.py (100%) rename tests/st/{pynative => }/ms_function/test_pynative_lenet_ms_function.py (100%) rename tests/st/{pynative => }/ms_function/test_pynative_ms_function.py (100%) delete mode 100644 tests/st/pynative/non_tensor_input/test_pynative_outermost_non_tensor.py diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc index 94ce61e0f3d..4b1ea06f456 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc @@ -338,13 +338,13 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager }), root_parameters.end()); size_t remove_param_count = origin_param_count - root_parameters.size(); - size_t hyper_param_count = root_graph->hyper_param_count(); - if (remove_param_count > hyper_param_count) { + size_t fv_param_count = root_graph->fv_param_count(); + if (remove_param_count > fv_param_count) { MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters."; return; } - hyper_param_count = hyper_param_count - remove_param_count; - root_graph->set_hyper_param_count(hyper_param_count); + fv_param_count = fv_param_count - remove_param_count; + root_graph->set_fv_param_count(fv_param_count); manager->SetParameters(root_graph, root_parameters); } } // namespace diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h index 74b42b84058..2ffa6f427e0 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h @@ -176,9 +176,9 @@ static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr & // 2. The arguments in caller may be less than the formal parameters in called as some parameters can have // default value. if (!called->has_vararg() && - caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->hyper_param_count())) { + caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->fv_param_count())) { size_t start_offset = called->GetPositionalArgsCount() + 1; - size_t end_offset = called->hyper_param_count(); + size_t end_offset = called->fv_param_count(); new_args.erase(new_args.begin() + start_offset, new_args.end() - end_offset); } diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc index d9c302c8356..6b97dba3221 100644 --- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc +++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc @@ -52,9 +52,7 @@ ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet ¶meter auto cache_name = ori_param_name + "_cache"; new_param_info->set_name(cache_name); new_tensor->set_param_info(new_param_info); - auto cache_param = graph->AddWeightParameter(cache_name); - cache_param->set_default_param(MakeValue(new_tensor)); - cache_param->set_abstract(new_tensor->ToAbstract()); + auto cache_param = graph->AddFvParameter(cache_name, new_tensor); cache_host_params_map[cache_param] = param; } } @@ -260,10 +258,7 @@ AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, std::string hashmap_name = "cache_hashmap"; new_param_info->set_name(hashmap_name); new_tensor->set_param_info(new_param_info); - auto hashmap = func_graph->AddWeightParameter(hashmap_name); - hashmap->set_default_param(MakeValue(new_tensor)); - hashmap->set_abstract(new_tensor->ToAbstract()); - return hashmap; + return func_graph->AddFvParameter(hashmap_name, new_tensor); } AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) { @@ -273,10 +268,7 @@ AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) { std::string step_name = "cache_step"; new_param_info->set_name(step_name); new_tensor->set_param_info(new_param_info); - auto step = func_graph->AddWeightParameter(step_name); - step->set_default_param(MakeValue(new_tensor)); - step->set_abstract(new_tensor->ToAbstract()); - return step; + return func_graph->AddFvParameter(step_name, new_tensor); } AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices, @@ -540,11 +532,7 @@ AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &or auto new_param_name = name + "_pipe"; new_param_info->set_name(new_param_name); new_tensor->set_param_info(new_param_info); - auto new_param = graph->AddWeightParameter(new_param_name); - new_param->set_default_param(MakeValue(new_tensor)); - auto abs_tensor = new_tensor->ToAbstract(); - new_param->set_abstract(abs_tensor); - return new_param->cast(); + return graph->AddFvParameter(new_param_name, new_tensor); } AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) { diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index 8683d91043a..880bf98bf32 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -1085,7 +1085,7 @@ void PipelineTransformer::ModifyParameterList() { } } auto del_num = parameters.size() - parameter_list.size(); - root_->set_hyper_param_count(root_->hyper_param_count() - del_num); + root_->set_fv_param_count(root_->fv_param_count() - del_num); manager_->SetParameters(root_, parameter_list); } } // namespace parallel diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 598accbfd79..0d23f1d9404 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -175,14 +175,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object } } if (para_node == nullptr) { - auto node = top_func_graph->AddWeightParameter(param_name); auto value = py::cast(obj); + para_node = top_func_graph->AddFvParameter(param_name, value); param_obj_ids.emplace_back(obj_id); - node->set_default_param(value); - // Set abstract for parameter - auto abs = value->ToAbstract(); - node->set_abstract(abs); - para_node = node; MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString() << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString(); } @@ -224,8 +219,8 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) { // Update top_graph top_graph->add_parameter(param_ptr); - size_t hyper_param_count = top_graph->hyper_param_count(); - top_graph->set_hyper_param_count(hyper_param_count + 1); + size_t fv_param_count = top_graph->fv_param_count(); + top_graph->set_fv_param_count(fv_param_count + 1); } else { input_params.push_back(param_ptr); } diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index 3b8f8682f50..38812886341 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -469,8 +469,8 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple auto func_graph = output->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); auto params = func_graph->parameters(); - if ((args.size() + func_graph->hyper_param_count()) != params.size()) { - MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count() + if ((args.size() + func_graph->fv_param_count()) != params.size()) { + MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->fv_param_count() << " not equal to graph input size " << params.size() << ", let graph to be executed."; } @@ -479,9 +479,9 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters"; } size_t index = it - params.cbegin(); - if (index >= args.size() + func_graph->hyper_param_count()) { + if (index >= args.size() + func_graph->fv_param_count()) { MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size() - << " add Parameter count " << func_graph->hyper_param_count() << "."; + << " add Parameter count " << func_graph->fv_param_count() << "."; } if (index < args.size()) { *ret_val = args[index]; diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 7d55f6b5a46..3b31c226960 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -41,7 +41,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info) has_kwarg_(false), exist_multi_target_(false), kw_only_args_count_(0), - hyper_param_count_(0), + fv_param_count_(0), is_generated_(false), return_(nullptr), manager_(), @@ -91,54 +91,56 @@ const std::vector FuncGraph::get_inputs() const { ParameterPtr FuncGraph::add_parameter() { FuncGraphPtr this_func_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_func_graph); - add_parameter(p); - return p; + ParameterPtr param = std::make_shared(this_func_graph); + add_parameter(param); + return param; } ParameterPtr FuncGraph::add_parameter(NodeDebugInfoPtr &&debug_info) { FuncGraphPtr this_func_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_func_graph, std::move(debug_info)); - add_parameter(p); - return p; + ParameterPtr param = std::make_shared(this_func_graph, std::move(debug_info)); + add_parameter(param); + return param; } -void FuncGraph::add_parameter(const ParameterPtr &p) { +void FuncGraph::add_parameter(const ParameterPtr ¶m) { if (manager_.lock()) { - manager_.lock()->AddParameter(shared_from_base(), p); + manager_.lock()->AddParameter(shared_from_base(), param); } else { - parameters_.push_back(p); + parameters_.push_back(param); } } ParameterPtr FuncGraph::InsertFrontParameter() { FuncGraphPtr this_func_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_func_graph); - InsertFrontParameter(p); - return p; + ParameterPtr param = std::make_shared(this_func_graph); + InsertFrontParameter(param); + return param; } -void FuncGraph::InsertFrontParameter(const ParameterPtr &p) { +void FuncGraph::InsertFrontParameter(const ParameterPtr ¶m) { if (manager_.lock()) { - manager_.lock()->InsertFrontParameter(shared_from_base(), p); + manager_.lock()->InsertFrontParameter(shared_from_base(), param); } else { - PrependParameter(p); + PrependParameter(param); } } -ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { +ParameterPtr FuncGraph::AddFvParameter(const std::string &name, const ValuePtr &default_value) { FuncGraphPtr this_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_graph); - p->set_name(name); - p->debug_info()->set_name(name); - + ParameterPtr param = std::make_shared(this_graph); + param->set_name(name); + param->debug_info()->set_name(name); + MS_EXCEPTION_IF_NULL(default_value); + param->set_default_param(default_value); + param->set_abstract(default_value->ToAbstract()); if (manager_.lock()) { - manager_.lock()->AddParameter(shared_from_base(), p); + manager_.lock()->AddParameter(shared_from_base(), param); } else { - parameters_.push_back(p); + parameters_.push_back(param); } - hyper_param_count_++; - return p; + ++fv_param_count_; + return param; } bool FuncGraph::has_flag(const std::string &key) const { @@ -573,11 +575,11 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() { min_param_num += 1; } min_param_num += kw_only_args_count_; - min_param_num += hyper_param_count_; + min_param_num += fv_param_count_; if (parameters_.size() < min_param_num) { MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() - << " which less than the sum of following: hyper_param_count: " << hyper_param_count_ + << " which less than the sum of following: fv_param_count: " << fv_param_count_ << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_ << ", kw_only_args_count_: " << kw_only_args_count_; } @@ -598,22 +600,22 @@ std::string FuncGraph::GetVariableArgName() { AnfNodePtr FuncGraph::GetVariableKwargParameter() { if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + if (parameters_.size() < fv_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_ + << ", parameters is less than 1 + fv_param_count"; } - return parameters_[(parameters_.size() - hyper_param_count_) - 1]; + return parameters_[(parameters_.size() - fv_param_count_) - 1]; } return nullptr; } std::string FuncGraph::GetVariableKwargName() { if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + if (parameters_.size() < fv_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_ + << ", parameters is less than 1 + fv_param_count"; } - const auto ¶meter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast(); + const auto ¶meter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast(); MS_EXCEPTION_IF_NULL(parameter); return parameter->name(); } @@ -637,17 +639,17 @@ AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() { varargs_kwargs_num += 1; } min_param_num += kw_only_args_count_; - min_param_num += hyper_param_count_; + min_param_num += fv_param_count_; if (parameters_.size() < min_param_num) { MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() - << " which less than the sum of following: hyper_param_count: " << hyper_param_count_ + << " which less than the sum of following: fv_param_count: " << fv_param_count_ << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_ << ", kw_only_args_count: " << kw_only_args_count_; } size_t kw_only_args_start_offset = parameters_.size() - min_param_num; - std::copy(parameters_.cbegin() + kw_only_args_start_offset, - parameters_.cend() - hyper_param_count_ - varargs_kwargs_num, std::back_inserter(kw_only_args)); + std::copy(parameters_.cbegin() + kw_only_args_start_offset, parameters_.cend() - fv_param_count_ - varargs_kwargs_num, + std::back_inserter(kw_only_args)); return kw_only_args; } @@ -659,7 +661,7 @@ int FuncGraph::GetPositionalArgsCount() const { if (has_vararg_) { count--; } - return (count - kw_only_args_count_) - SizeToInt(hyper_param_count_); + return (count - kw_only_args_count_) - SizeToInt(fv_param_count_); } AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { @@ -763,13 +765,6 @@ CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const std::ve return NewCNodeInOrder(std::move(input_node_list)); } -ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) { - auto parameter = add_parameter(); - parameter->set_default_param(MakeValue(meta_tensor)); - parameter->set_abstract(meta_tensor->ToAbstract()); - return parameter; -} - void FuncGraph::SetMultiTarget() { auto graph_manager = manager(); MS_EXCEPTION_IF_NULL(graph_manager); diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index ff7cdc6ccc1..049c9f7bab9 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -132,8 +132,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder { void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); } void set_parameters(const std::vector ¶ms) { parameters_ = params; } void set_parameters(std::vector &¶ms) { parameters_ = std::move(params); } - // Add a weight parameter with specific name. - ParameterPtr AddWeightParameter(const std::string &name); + // Add a FV weight parameter with specific name. + ParameterPtr AddFvParameter(const std::string &name, const ValuePtr &default_value); // Create a cnode with given inputs, bound to this graph. virtual CNodePtr NewCNode(std::vector &&inputs); @@ -154,7 +154,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder { // Create a cnode with given inputs, put it to order list after the position node. CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector &inputs); - virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor); // Functions for handling variable argument, keyword-only arguments and variable keyword argument. AnfNodePtr GetDefaultValueByName(const std::string &name); void set_param_default_value(const std::string &name, const AnfNodePtr &node) { @@ -176,8 +175,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder { AnfNodePtr GetVariableKwargParameter(); std::string GetVariableKwargName(); AnfNodePtrList GetKwOnlyArgsParameters(); - void set_hyper_param_count(size_t count) { hyper_param_count_ = count; } - size_t hyper_param_count() const { return hyper_param_count_; } + void set_fv_param_count(size_t count) { fv_param_count_ = count; } + size_t fv_param_count() const { return fv_param_count_; } int GetPositionalArgsCount() const; AnfNodePtr GetParameterByName(const std::string &name); bool NeedGenerate(const std::vector &kwarg_list); @@ -418,9 +417,9 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder { bool has_kwarg_; bool exist_multi_target_; int kw_only_args_count_; - // Hyper param is placed on the top graph, + // Hyper param is used as free variable and placed on the top graph. // and positioned in the end of the param list, so we record the number to trace the position. - size_t hyper_param_count_; + size_t fv_param_count_; // Argument input list for the graph used to generate this graph. bool is_generated_; // CNode that calls 'return' primitive. diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 2d381991975..9409c8f0961 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -256,7 +256,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr target_func_graph->set_has_vararg(func_graph->has_vararg()); target_func_graph->set_has_kwarg(func_graph->has_kwarg()); target_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); - target_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); + target_func_graph->set_fv_param_count(func_graph->fv_param_count()); target_func_graph->set_is_generate(func_graph->is_generated()); target_func_graph->set_stub(func_graph->stub()); target_func_graph->set_switch_input(func_graph->switch_input()); @@ -822,7 +822,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP new_func_graph->set_has_vararg(func_graph->has_vararg()); new_func_graph->set_has_kwarg(func_graph->has_kwarg()); new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); - new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); + new_func_graph->set_fv_param_count(func_graph->fv_param_count()); new_func_graph->set_is_generate(func_graph->is_generated()); new_func_graph->set_stub(func_graph->stub()); new_func_graph->set_switch_input(func_graph->switch_input()); diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 8e0d46e8d4a..a54052f09ce 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -196,7 +196,7 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, const std::vector &specialized_parameter_list, mindspore::HashMap *repl_nodes) const { MS_EXCEPTION_IF_NULL(specialized_graph); - for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { + for (size_t i = 0; i < specialized_graph->parameters().size() - fv_param_count(); ++i) { MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]); auto param_node = specialized_graph->parameters()[i]->cast(); MS_EXCEPTION_IF_NULL(param_node); @@ -222,10 +222,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) std::vector kwarg_list; std::vector pos_arg_indexes; size_t arguments_count = args_spec_list.size(); - if (hyper_param_count_ > arguments_count) { + if (fv_param_count_ > arguments_count) { MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments."; } - for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) { + for (size_t i = 0; i < arguments_count - fv_param_count_; i++) { MS_EXCEPTION_IF_NULL(args_spec_list[i]); if (args_spec_list[i]->isa()) { kwarg_list.push_back(args_spec_list[i]->cast()); @@ -243,7 +243,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) } FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); size_t kwarg_count = kwarg_list.size(); - int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - hyper_param_count_); + int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - fv_param_count_); int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); int variable_args_count = pos_args_input_count - pos_args_count; std::vector specialized_parameter_list; @@ -263,7 +263,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) // append hyper parameter to specialized_parameter_list MS_EXCEPTION_IF_NULL(specialized_graph); auto params = specialized_graph->parameters(); - specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_), + specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(fv_param_count_), params.end()); std::vector specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(), specialized_parameter_list.end()); diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index e9c9a4c0b08..d692e566778 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -1488,7 +1488,7 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map &weights) { - size_t hyper_param_count = 0; + size_t fv_param_count = 0; auto parameters = topGraph->parameters(); for (int i = parameters.size() - 1; i >= 0; --i) { size_t index = IntToSize(i); @@ -1512,9 +1512,9 @@ bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph return false; } parameter->set_default_param(weights_iter->second); - hyper_param_count++; + fv_param_count++; } - topGraph->set_hyper_param_count(hyper_param_count); + topGraph->set_fv_param_count(fv_param_count); return true; } diff --git a/tests/st/construct_input/test_outermost_input.py b/tests/st/construct_input/test_outermost_input.py new file mode 100644 index 00000000000..84ab095d9d0 --- /dev/null +++ b/tests/st/construct_input/test_outermost_input.py @@ -0,0 +1,290 @@ +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test outermost net pass non_tensor inputs""" +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor, Parameter, ParameterTuple +from mindspore.ops import composite as C +from mindspore.ops import operations as P +import mindspore.ops as ops +from mindspore import context + + +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + yield + context.set_context(mode=context.GRAPH_MODE) + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + + def construct(self, tensor_param_x, tuple_a, list_b, tensor_param_y, tensor_param_z, dict_c): + out = self.add(tensor_param_x, tuple_a[0]) + out = self.sub(out, list_b[1][1]["y"]) + out = self.add(out, tensor_param_y) + out = self.sub(out, tensor_param_z) + out = self.add(out, dict_c["u"]) + return out + + +class GradNet(nn.Cell): + def __init__(self, net, get_all): + super(GradNet, self).__init__() + self.forward_net = net + self.sens = Tensor(np.ones((2, 2), np.float32) * 5) + self.grad_all = C.GradOperation(get_all=get_all) + + def construct(self, tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c): + return self.grad_all(self.forward_net)(tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c) + + +tensor_x = Tensor(np.ones((2, 2), np.float32)) +tensor_y = Tensor(np.ones((2, 2), np.float32) * 2) +tensor_z = Tensor(np.ones((2, 2), np.float32) * 3) +tensor_w = Tensor(np.ones((2, 2), np.float32) * 4) +tensor_p = Tensor(np.ones((2, 2), np.float32) * 5) +tensor_u = Tensor(np.ones((2, 2), np.float32) * 6) +tuple_arg = (tensor_x, tensor_y, tensor_z, tensor_w) +list_arg = [[tensor_x, tensor_x], [[tensor_x, tensor_y], {"x": tensor_x, "y": tensor_y, "z": tensor_z, "p": tensor_p}]] +dict_arg = {"x": tensor_x, "y": tensor_y, "u": tensor_u} + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_non_tensor_inputs(mode): + """ + Feature: Construct()/ms_function input type with back propagate. + Description: Normal input type without tensor. + Expectation: No exception. + """ + context.set_context(mode=mode) + # grad first input + grad_fist_input_tensor_net = GradNet(Net(), get_all=False) + ret = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg) + assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32)) + # grad all inputs + grad_all_input_tensor_net = GradNet(Net(), get_all=True) + ret_all = grad_all_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg) + assert len(ret_all) == 3 + assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32)) + assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32)) + assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1) + + +class GradNet1(nn.Cell): + def __init__(self, net, get_all): + super(GradNet1, self).__init__() + self.forward_net = net + self.sens = Tensor(np.ones((2, 2), np.float32) * 5) + self.grad_all = C.GradOperation(get_all=get_all) + + def construct(self, tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c): + return self.grad_all(self.forward_net)(tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c) + + +# PyNative run error. +# Support context.PYNATIVE_MODE later. +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_grad_first_input_net(mode): + """ + Feature: Construct()/ms_function input type with back propagate. + Description: Normal input type. + Expectation: No exception. + """ + class FirstInputTensorNet(nn.Cell): + def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c): + return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"] + + context.set_context(mode=mode) + grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False) + res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg) + print('res:', res) + assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32)) + + +class TestCell(nn.Cell): + def __init__(self, param): + super().__init__() + self.a = Tensor(np.array([[1, 2], [3, 4]])) + self.param = param + + def construct(self, x): + return self.a * self.param * x + + +class GradCellWithParameter(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + self.grad = ops.GradOperation(get_all=True, get_by_list=True) + self.param = self.net.param + + def construct(self, x): + return self.grad(self.net, self.param)(x) + + +class GradCell(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + self.grad_all = ops.GradOperation(get_all=True) + + def construct(self, x): + return self.grad_all(self.net)(x) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_grad_parameter_input(mode): + """ + Feature: Construct()/ms_function input type with back propagate. + Description: Grad with Parameter as input type. + Expectation: No exception. + """ + context.set_context(mode=mode) + x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') + y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y') + z = Tensor(np.array([[7, 8], [9, 0]])) + a = GradCell(TestCell(x))(y) + b = GradCell(TestCell(x))(z) + print(f'a: {a}') + print(f'b: {b}') + assert np.array_equal(a[0].asnumpy(), b[0].asnumpy()) + + +# PyNative run error. +# Support context.PYNATIVE_MODE later. +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_grad_parameter_as_input_and_fv(mode): + """ + Feature: Construct()/ms_function input type with back propagate. + Description: Grad with Parameters as input type and fv. + Expectation: No exception. + """ + context.set_context(mode=mode) + x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') + y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y') + z = Tensor(np.array([[7, 8], [9, 0]])) + a = GradCellWithParameter(TestCell(x))(y) + b = GradCellWithParameter(TestCell(x))(z) + print(f'a: {a}') + print(f'b: {b}') + assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy()) + assert np.array_equal(a[1].asnumpy(), b[1].asnumpy()) + + +# PyNative run error. +# Support context.PYNATIVE_MODE later. +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) +def test_grad_same_parameter_both_input_and_fv(mode): + """ + Feature: Construct()/ms_function input type with back propagate. + Description: Grad with the same Parameter used as input type and fv at the same time. + Expectation: No exception. + """ + context.set_context(mode=mode) + x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') + y = Tensor(np.array([[1, 2], [3, 4]])) + a = GradCellWithParameter(TestCell(x))(x) + b = GradCellWithParameter(TestCell(x))(y) + print(f'a: {a}') + print(f'b: {b}') + assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy()) + assert np.array_equal(a[1].asnumpy(), b[1].asnumpy()) + + +class TestCell2(nn.Cell): + def __init__(self, param1, param2): + super().__init__() + self.a = Tensor(np.array([[1, 2], [3, 4]])) + self.param1 = param1 + self.param2 = param2 + + def construct(self, x): + return self.a * self.param1 * self.param2 * x + + +class GradCellWithParameterTuple(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + self.grad = ops.GradOperation(get_all=True, get_by_list=True) + self.param1 = self.net.param1 + self.param2 = self.net.param2 + self.params = ParameterTuple([self.param1, self.param2]) + + def construct(self, x): + return self.grad(self.net, self.params)(x) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_grad_parameter_as_input_and_fv2(mode): + """ + Feature: Construct()/ms_function input type with back propagate. + Description: Grad with Parameters as input type and fv. ParameterTuple as fv. + Expectation: No exception. + """ + context.set_context(mode=mode) + x1 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x1') + x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2') + y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y') + z = Tensor(np.array([[7, 8], [9, 0]])) + a = GradCellWithParameterTuple(TestCell2(x1, x2))(y) + b = GradCellWithParameterTuple(TestCell2(x1, x2))(z) + print(f'a: {a}') + print(f'b: {b}') + assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy()) + assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy()) + assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy()) diff --git a/tests/st/pynative/ms_function/test_inner_dyn_shape_ms_function.py b/tests/st/ms_function/test_inner_dyn_shape_ms_function.py similarity index 100% rename from tests/st/pynative/ms_function/test_inner_dyn_shape_ms_function.py rename to tests/st/ms_function/test_inner_dyn_shape_ms_function.py diff --git a/tests/st/pynative/ms_function/test_nested_calling_ms_function.py b/tests/st/ms_function/test_nested_calling_ms_function.py similarity index 100% rename from tests/st/pynative/ms_function/test_nested_calling_ms_function.py rename to tests/st/ms_function/test_nested_calling_ms_function.py diff --git a/tests/st/pynative/ms_function/test_outmost_dyn_shape_ms_function.py b/tests/st/ms_function/test_outmost_dyn_shape_ms_function.py similarity index 100% rename from tests/st/pynative/ms_function/test_outmost_dyn_shape_ms_function.py rename to tests/st/ms_function/test_outmost_dyn_shape_ms_function.py diff --git a/tests/st/pynative/ms_function/test_pynative_lenet_ms_function.py b/tests/st/ms_function/test_pynative_lenet_ms_function.py similarity index 100% rename from tests/st/pynative/ms_function/test_pynative_lenet_ms_function.py rename to tests/st/ms_function/test_pynative_lenet_ms_function.py diff --git a/tests/st/pynative/ms_function/test_pynative_ms_function.py b/tests/st/ms_function/test_pynative_ms_function.py similarity index 100% rename from tests/st/pynative/ms_function/test_pynative_ms_function.py rename to tests/st/ms_function/test_pynative_ms_function.py diff --git a/tests/st/pynative/non_tensor_input/test_pynative_outermost_non_tensor.py b/tests/st/pynative/non_tensor_input/test_pynative_outermost_non_tensor.py deleted file mode 100644 index bba3652869f..00000000000 --- a/tests/st/pynative/non_tensor_input/test_pynative_outermost_non_tensor.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test outermost net pass non_tensor inputs""" -import numpy as np -import pytest - -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from mindspore import context - -context.set_context(mode=context.PYNATIVE_MODE) - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.add = P.TensorAdd() - self.sub = P.Sub() - - def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c): - out = self.add(tensor_x, tuple_a[0]) - out = self.sub(out, list_b[1][1]["y"]) - out = self.add(out, tensor_y) - out = self.sub(out, tensor_z) - out = self.add(out, dict_c["u"]) - return out - - -class GradNet(nn.Cell): - def __init__(self, net, get_all): - super(GradNet, self).__init__() - self.forward_net = net - self.sens = Tensor(np.ones((2, 2), np.float32) * 5) - self.grad_all = C.GradOperation(get_all=get_all) - - def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c): - return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c) - - -x = Tensor(np.ones((2, 2), np.float32)) -y = Tensor(np.ones((2, 2), np.float32) * 2) -z = Tensor(np.ones((2, 2), np.float32) * 3) -w = Tensor(np.ones((2, 2), np.float32) * 4) -p = Tensor(np.ones((2, 2), np.float32) * 5) -u = Tensor(np.ones((2, 2), np.float32) * 6) -arg_t0 = (x, y, z, w) -arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": z, "p": p}]] -args_d0 = {"x": x, "y": y, "u": u} - - -@pytest.mark.level1 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_non_tensor_inputs(): - # grad first input - grad_fist_input_tensor_net = GradNet(Net(), get_all=False) - ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0) - assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32)) - # grad all inputs - grad_all_input_tensor_net = GradNet(Net(), get_all=True) - ret_all = grad_all_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0) - assert len(ret_all) == 3 - assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32)) - assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32)) - assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1) diff --git a/tests/ut/cpp/pre_activate/ascend/dynamic_shape/dynamic_shape_pass_test.cc b/tests/ut/cpp/pre_activate/ascend/dynamic_shape/dynamic_shape_pass_test.cc index 10e08e69d4c..ce421a56cae 100644 --- a/tests/ut/cpp/pre_activate/ascend/dynamic_shape/dynamic_shape_pass_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/dynamic_shape/dynamic_shape_pass_test.cc @@ -36,7 +36,7 @@ constexpr auto kDependRealInputSize = 2; ParameterPtr TestCreateParameter(const KernelGraphPtr &g, const std::string &name, const abstract::AbstractBasePtr &abstract) { MS_EXCEPTION_IF_NULL(g); - auto parameter = g->AddWeightParameter(name); + auto parameter = g->AddFvParameter(name, abstract->BuildValue()); if (parameter == nullptr) { MS_LOG(ERROR) << "Cannot add weight parameter!"; } diff --git a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py index d303f815dd8..72b3fad8fe9 100644 --- a/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py +++ b/tests/ut/python/pipeline/parse/test_outermost_net_pass_non_tensor_inputs.py @@ -19,9 +19,9 @@ import pytest import mindspore.nn as nn from mindspore.common import mutable from mindspore import Tensor, Parameter, ParameterTuple -from mindspore import context from mindspore.ops import composite as C import mindspore.ops as ops +from mindspore import context @pytest.fixture(scope="module", autouse=True) @@ -91,29 +91,7 @@ def test_grad_first_input_net(mode): context.set_context(mode=mode) grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False) - res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg) - print('res:', res) - assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32)) - - -# PyNative run error. -# Support context.PYNATIVE_MODE later. -@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) -def test_grad_first_input_net_pynative_error(mode): - """ - Feature: Construct()/ms_function input type with back propagate. - Description: Normal input type. - Expectation: No exception. - """ - class FirstInputTensorNet(nn.Cell): - def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c): - return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"] - - context.set_context(mode=mode) - grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False) - res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg) - print('res:', res) - assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32)) + grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg) @pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE]) @@ -149,7 +127,6 @@ def test_outermost_net_pass_parameter(mode): # Support the Parameter as outermost input. -# Support context.PYNATIVE_MODE UT later. @pytest.mark.parametrize('mode', [context.GRAPH_MODE]) def test_outermost_net_pass_tuple_including_parameter(mode): """ @@ -163,7 +140,6 @@ def test_outermost_net_pass_tuple_including_parameter(mode): # Support the Parameter as outermost input. -# Support context.PYNATIVE_MODE UT later. @pytest.mark.parametrize('mode', [context.GRAPH_MODE]) def test_outermost_net_pass_list_including_parameter(mode): """ @@ -177,7 +153,6 @@ def test_outermost_net_pass_list_including_parameter(mode): # Support the Parameter as outermost input. -# Support context.PYNATIVE_MODE UT later. @pytest.mark.parametrize('mode', [context.GRAPH_MODE]) def test_grad_net_pass_dict_including_parameter(mode): """ @@ -190,96 +165,6 @@ def test_grad_net_pass_dict_including_parameter(mode): forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, SCALAR_NUM, mutable_dict, flag_0) -class TestCell(nn.Cell): - def __init__(self, param): - super().__init__() - self.a = Tensor(np.array([[1, 2], [3, 4]])) - self.param = param - - def construct(self, x): - return self.a * self.param * x - - -class GradCellWithParameter(nn.Cell): - def __init__(self, net): - super().__init__() - self.net = net - self.grad = ops.GradOperation(get_all=True, get_by_list=True) - self.param = self.net.param - - def construct(self, x): - return self.grad(self.net, self.param)(x) - - -class GradCell(nn.Cell): - def __init__(self, net): - super().__init__() - self.net = net - self.grad_all = ops.GradOperation(get_all=True) - - def construct(self, x): - return self.grad_all(self.net)(x) - - -@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE]) -def test_grad_parameter_input(mode): - """ - Feature: Construct()/ms_function input type with back propagate. - Description: Grad with Parameter as input type. - Expectation: No exception. - """ - context.set_context(mode=mode) - x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') - y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y') - z = Tensor(np.array([[7, 8], [9, 0]])) - a = GradCell(TestCell(x))(y) - b = GradCell(TestCell(x))(z) - print(f'a: {a}') - print(f'b: {b}') - assert np.array_equal(a[0].asnumpy(), b[0].asnumpy()) - - -# PyNative run error. -# Support context.PYNATIVE_MODE later. -@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) -def test_grad_parameter_as_input_and_fv(mode): - """ - Feature: Construct()/ms_function input type with back propagate. - Description: Grad with Parameters as input type and fv. - Expectation: No exception. - """ - context.set_context(mode=mode) - x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') - y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y') - z = Tensor(np.array([[7, 8], [9, 0]])) - a = GradCellWithParameter(TestCell(x))(y) - b = GradCellWithParameter(TestCell(x))(z) - print(f'a: {a}') - print(f'b: {b}') - assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy()) - assert np.array_equal(a[1].asnumpy(), b[1].asnumpy()) - - -# PyNative run error. -# Support context.PYNATIVE_MODE later. -@pytest.mark.parametrize('mode', [context.GRAPH_MODE]) -def test_grad_same_parameter_both_input_and_fv(mode): - """ - Feature: Construct()/ms_function input type with back propagate. - Description: Grad with the same Parameter used as input type and fv at the same time. - Expectation: No exception. - """ - context.set_context(mode=mode) - x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x') - y = Tensor(np.array([[1, 2], [3, 4]])) - a = GradCellWithParameter(TestCell(x))(x) - b = GradCellWithParameter(TestCell(x))(y) - print(f'a: {a}') - print(f'b: {b}') - assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy()) - assert np.array_equal(a[1].asnumpy(), b[1].asnumpy()) - - class TestCell2(nn.Cell): def __init__(self, param1, param2): super().__init__() @@ -329,7 +214,7 @@ class GradCellWithTupleOfParameter(nn.Cell): @pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE]) -def test_grad_parameter_as_input_and_fv2(mode): +def test_grad_parameter_tuple(mode): """ Feature: Construct()/ms_function input type with back propagate. Description: Grad with Parameters as input type and fv. ParameterTuple as fv. @@ -340,13 +225,8 @@ def test_grad_parameter_as_input_and_fv2(mode): x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2') y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y') z = Tensor(np.array([[7, 8], [9, 0]])) - a = GradCellWithParameterTuple(TestCell2(x1, x2))(y) - b = GradCellWithParameterTuple(TestCell2(x1, x2))(z) - print(f'a: {a}') - print(f'b: {b}') - assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy()) - assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy()) - assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy()) + GradCellWithParameterTuple(TestCell2(x1, x2))(y) + GradCellWithParameterTuple(TestCell2(x1, x2))(z) @pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')