!35174 Change the weight parameter to FV parameter in FuncGraph.

Merge pull request !35174 from 张清华/opt_parameter
Committed by i-robot (via Gitee) on 2022-06-01 03:32:28 +00:00
commit 6c1ea8074e
20 changed files with 372 additions and 307 deletions
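
In short, this PR renames FuncGraph's "hyper parameter" bookkeeping to "FV (free variable) parameter" and folds the manual default-value/abstract setup into the new AddFvParameter interface. A schematic before/after of a typical call site, distilled from the hunks below (illustrative only, not compilable on its own; graph, param_name and new_tensor stand for whatever the caller already has):

    // Before: create the weight parameter, then attach default value and abstract by hand.
    auto weight_param = graph->AddWeightParameter(param_name);
    weight_param->set_default_param(MakeValue(new_tensor));
    weight_param->set_abstract(new_tensor->ToAbstract());

    // After: pass the default value in; AddFvParameter sets the abstract and
    // increments fv_param_count_ itself.
    auto fv_param = graph->AddFvParameter(param_name, new_tensor);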

View File

@@ -338,13 +338,13 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
                          }),
                          root_parameters.end());
   size_t remove_param_count = origin_param_count - root_parameters.size();
-  size_t hyper_param_count = root_graph->hyper_param_count();
-  if (remove_param_count > hyper_param_count) {
+  size_t fv_param_count = root_graph->fv_param_count();
+  if (remove_param_count > fv_param_count) {
     MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
     return;
   }
-  hyper_param_count = hyper_param_count - remove_param_count;
-  root_graph->set_hyper_param_count(hyper_param_count);
+  fv_param_count = fv_param_count - remove_param_count;
+  root_graph->set_fv_param_count(fv_param_count);
   manager->SetParameters(root_graph, root_parameters);
 }
 }  // namespace

View File

@@ -176,9 +176,9 @@ static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &
   // 2. The arguments in caller may be less than the formal parameters in called as some parameters can have
   //    default value.
   if (!called->has_vararg() &&
-      caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->hyper_param_count())) {
+      caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->fv_param_count())) {
     size_t start_offset = called->GetPositionalArgsCount() + 1;
-    size_t end_offset = called->hyper_param_count();
+    size_t end_offset = called->fv_param_count();
     new_args.erase(new_args.begin() + start_offset, new_args.end() - end_offset);
   }
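
For orientation, the erase() above drops only the slots that correspond to defaulted parameters while keeping the leading positionals and the trailing FV parameters. A minimal, self-contained sketch of that index arithmetic (hypothetical helper, plain ints standing in for AnfNodePtr, not code from this PR):

    #include <cstddef>
    #include <vector>

    // Sketch only: inputs = [callee, positionals..., defaulted..., FV params...].
    // Erasing [begin() + P + 1, end() - F) removes exactly the defaulted slots.
    std::vector<int> TrimDefaultedArgs(std::vector<int> inputs, std::size_t positional_count, std::size_t fv_count) {
      const std::size_t start_offset = positional_count + 1;  // skip inputs[0] (the callee) and the positionals
      inputs.erase(inputs.begin() + static_cast<std::ptrdiff_t>(start_offset),
                   inputs.end() - static_cast<std::ptrdiff_t>(fv_count));
      return inputs;
    }

For example, with one callee slot, two positionals, three defaulted slots and one FV parameter (seven inputs in total), the result keeps four entries.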

View File

@@ -52,9 +52,7 @@ ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter
       auto cache_name = ori_param_name + "_cache";
       new_param_info->set_name(cache_name);
       new_tensor->set_param_info(new_param_info);
-      auto cache_param = graph->AddWeightParameter(cache_name);
-      cache_param->set_default_param(MakeValue(new_tensor));
-      cache_param->set_abstract(new_tensor->ToAbstract());
+      auto cache_param = graph->AddFvParameter(cache_name, new_tensor);
       cache_host_params_map[cache_param] = param;
     }
   }
@@ -260,10 +258,7 @@ AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size,
   std::string hashmap_name = "cache_hashmap";
   new_param_info->set_name(hashmap_name);
   new_tensor->set_param_info(new_param_info);
-  auto hashmap = func_graph->AddWeightParameter(hashmap_name);
-  hashmap->set_default_param(MakeValue(new_tensor));
-  hashmap->set_abstract(new_tensor->ToAbstract());
-  return hashmap;
+  return func_graph->AddFvParameter(hashmap_name, new_tensor);
 }
 AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
@@ -273,10 +268,7 @@ AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
   std::string step_name = "cache_step";
   new_param_info->set_name(step_name);
   new_tensor->set_param_info(new_param_info);
-  auto step = func_graph->AddWeightParameter(step_name);
-  step->set_default_param(MakeValue(new_tensor));
-  step->set_abstract(new_tensor->ToAbstract());
-  return step;
+  return func_graph->AddFvParameter(step_name, new_tensor);
 }
 AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
@@ -540,11 +532,7 @@ AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &or
   auto new_param_name = name + "_pipe";
   new_param_info->set_name(new_param_name);
   new_tensor->set_param_info(new_param_info);
-  auto new_param = graph->AddWeightParameter(new_param_name);
-  new_param->set_default_param(MakeValue(new_tensor));
-  auto abs_tensor = new_tensor->ToAbstract();
-  new_param->set_abstract(abs_tensor);
-  return new_param->cast<AnfNodePtr>();
+  return graph->AddFvParameter(new_param_name, new_tensor);
 }
 AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {

View File

@@ -1085,7 +1085,7 @@ void PipelineTransformer::ModifyParameterList() {
     }
   }
   auto del_num = parameters.size() - parameter_list.size();
-  root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
+  root_->set_fv_param_count(root_->fv_param_count() - del_num);
   manager_->SetParameters(root_, parameter_list);
 }
 }  // namespace parallel

View File

@@ -175,14 +175,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
       }
     }
     if (para_node == nullptr) {
-      auto node = top_func_graph->AddWeightParameter(param_name);
       auto value = py::cast<tensor::MetaTensorPtr>(obj);
+      para_node = top_func_graph->AddFvParameter(param_name, value);
       param_obj_ids.emplace_back(obj_id);
-      node->set_default_param(value);
-      // Set abstract for parameter
-      auto abs = value->ToAbstract();
-      node->set_abstract(abs);
-      para_node = node;
       MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
                     << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
     }
@@ -224,8 +219,8 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
       // Update top_graph
       top_graph->add_parameter(param_ptr);
-      size_t hyper_param_count = top_graph->hyper_param_count();
-      top_graph->set_hyper_param_count(hyper_param_count + 1);
+      size_t fv_param_count = top_graph->fv_param_count();
+      top_graph->set_fv_param_count(fv_param_count + 1);
     } else {
       input_params.push_back(param_ptr);
     }

View File

@@ -477,8 +477,8 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
     auto func_graph = output->func_graph();
     MS_EXCEPTION_IF_NULL(func_graph);
     auto params = func_graph->parameters();
-    if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
-      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
+    if ((args.size() + func_graph->fv_param_count()) != params.size()) {
+      MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->fv_param_count()
                         << " not equal to graph input size " << params.size() << ", let graph to be executed.";
     }
@@ -487,9 +487,9 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
       MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
     }
     size_t index = it - params.cbegin();
-    if (index >= args.size() + func_graph->hyper_param_count()) {
+    if (index >= args.size() + func_graph->fv_param_count()) {
       MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
-                                 << " add Parameter count " << func_graph->hyper_param_count() << ".";
+                                 << " add Parameter count " << func_graph->fv_param_count() << ".";
     }
     if (index < args.size()) {
       *ret_val = args[index];

View File

@@ -41,7 +41,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
       has_kwarg_(false),
       exist_multi_target_(false),
       kw_only_args_count_(0),
-      hyper_param_count_(0),
+      fv_param_count_(0),
       is_generated_(false),
       return_(nullptr),
       manager_(),
@@ -91,54 +91,56 @@ const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
 ParameterPtr FuncGraph::add_parameter() {
   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
-  ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
-  add_parameter(p);
-  return p;
+  ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
+  add_parameter(param);
+  return param;
 }
 ParameterPtr FuncGraph::add_parameter(NodeDebugInfoPtr &&debug_info) {
   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
-  ParameterPtr p = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
-  add_parameter(p);
-  return p;
+  ParameterPtr param = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
+  add_parameter(param);
+  return param;
 }
-void FuncGraph::add_parameter(const ParameterPtr &p) {
+void FuncGraph::add_parameter(const ParameterPtr &param) {
   if (manager_.lock()) {
-    manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
+    manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
   } else {
-    parameters_.push_back(p);
+    parameters_.push_back(param);
   }
 }
 ParameterPtr FuncGraph::InsertFrontParameter() {
   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
-  ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
-  InsertFrontParameter(p);
-  return p;
+  ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
+  InsertFrontParameter(param);
+  return param;
 }
-void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
+void FuncGraph::InsertFrontParameter(const ParameterPtr &param) {
   if (manager_.lock()) {
-    manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
+    manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), param);
   } else {
-    PrependParameter(p);
+    PrependParameter(param);
   }
 }
-ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
+ParameterPtr FuncGraph::AddFvParameter(const std::string &name, const ValuePtr &default_value) {
   FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
-  ParameterPtr p = std::make_shared<Parameter>(this_graph);
-  p->set_name(name);
-  p->debug_info()->set_name(name);
+  ParameterPtr param = std::make_shared<Parameter>(this_graph);
+  param->set_name(name);
+  param->debug_info()->set_name(name);
+  MS_EXCEPTION_IF_NULL(default_value);
+  param->set_default_param(default_value);
+  param->set_abstract(default_value->ToAbstract());
   if (manager_.lock()) {
-    manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
+    manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
   } else {
-    parameters_.push_back(p);
+    parameters_.push_back(param);
   }
-  hyper_param_count_++;
-  return p;
+  ++fv_param_count_;
+  return param;
 }
 bool FuncGraph::has_flag(const std::string &key) const {
@@ -573,11 +575,11 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() {
     min_param_num += 1;
   }
   min_param_num += kw_only_args_count_;
-  min_param_num += hyper_param_count_;
+  min_param_num += fv_param_count_;
   if (parameters_.size() < min_param_num) {
     MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
-                      << " which less than the sum of following: hyper_param_count: " << hyper_param_count_
+                      << " which less than the sum of following: fv_param_count: " << fv_param_count_
                       << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
                      << ", kw_only_args_count_: " << kw_only_args_count_;
   }
@@ -598,22 +600,22 @@ std::string FuncGraph::GetVariableArgName() {
 AnfNodePtr FuncGraph::GetVariableKwargParameter() {
   if (has_kwarg_) {
-    if (parameters_.size() < hyper_param_count_ + 1) {
-      MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
-                        << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
+    if (parameters_.size() < fv_param_count_ + 1) {
+      MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
+                        << ", parameters is less than 1 + fv_param_count";
     }
-    return parameters_[(parameters_.size() - hyper_param_count_) - 1];
+    return parameters_[(parameters_.size() - fv_param_count_) - 1];
   }
   return nullptr;
 }
 std::string FuncGraph::GetVariableKwargName() {
   if (has_kwarg_) {
-    if (parameters_.size() < hyper_param_count_ + 1) {
-      MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
-                        << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
+    if (parameters_.size() < fv_param_count_ + 1) {
+      MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
+                        << ", parameters is less than 1 + fv_param_count";
     }
-    const auto &parameter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
+    const auto &parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast<ParameterPtr>();
     MS_EXCEPTION_IF_NULL(parameter);
     return parameter->name();
   }
@@ -637,17 +639,17 @@ AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
     varargs_kwargs_num += 1;
   }
   min_param_num += kw_only_args_count_;
-  min_param_num += hyper_param_count_;
+  min_param_num += fv_param_count_;
   if (parameters_.size() < min_param_num) {
     MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
-                      << " which less than the sum of following: hyper_param_count: " << hyper_param_count_
+                      << " which less than the sum of following: fv_param_count: " << fv_param_count_
                       << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
                       << ", kw_only_args_count: " << kw_only_args_count_;
   }
   size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
-  std::copy(parameters_.cbegin() + kw_only_args_start_offset,
-            parameters_.cend() - hyper_param_count_ - varargs_kwargs_num, std::back_inserter(kw_only_args));
+  std::copy(parameters_.cbegin() + kw_only_args_start_offset, parameters_.cend() - fv_param_count_ - varargs_kwargs_num,
+            std::back_inserter(kw_only_args));
   return kw_only_args;
 }
@@ -659,7 +661,7 @@ int FuncGraph::GetPositionalArgsCount() const {
   if (has_vararg_) {
     count--;
   }
-  return (count - kw_only_args_count_) - SizeToInt(hyper_param_count_);
+  return (count - kw_only_args_count_) - SizeToInt(fv_param_count_);
 }
 AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
@@ -763,13 +765,6 @@ CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const std::ve
   return NewCNodeInOrder(std::move(input_node_list));
 }
-ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
-  auto parameter = add_parameter();
-  parameter->set_default_param(MakeValue(meta_tensor));
-  parameter->set_abstract(meta_tensor->ToAbstract());
-  return parameter;
-}
 void FuncGraph::SetMultiTarget() {
   auto graph_manager = manager();
   MS_EXCEPTION_IF_NULL(graph_manager);

View File

@@ -132,8 +132,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
   void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
   void set_parameters(std::vector<AnfNodePtr> &&params) { parameters_ = std::move(params); }
-  // Add a weight parameter with specific name.
-  ParameterPtr AddWeightParameter(const std::string &name);
+  // Add a FV weight parameter with specific name.
+  ParameterPtr AddFvParameter(const std::string &name, const ValuePtr &default_value);
   // Create a cnode with given inputs, bound to this graph.
   virtual CNodePtr NewCNode(std::vector<AnfNodePtr> &&inputs);
@@ -154,7 +154,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   // Create a cnode with given inputs, put it to order list after the position node.
   CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);
-  virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
   // Functions for handling variable argument, keyword-only arguments and variable keyword argument.
   AnfNodePtr GetDefaultValueByName(const std::string &name);
   void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
@@ -176,8 +175,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   AnfNodePtr GetVariableKwargParameter();
   std::string GetVariableKwargName();
   AnfNodePtrList GetKwOnlyArgsParameters();
-  void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
-  size_t hyper_param_count() const { return hyper_param_count_; }
+  void set_fv_param_count(size_t count) { fv_param_count_ = count; }
+  size_t fv_param_count() const { return fv_param_count_; }
   int GetPositionalArgsCount() const;
   AnfNodePtr GetParameterByName(const std::string &name);
   bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
@@ -418,9 +417,9 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   bool has_kwarg_;
   bool exist_multi_target_;
   int kw_only_args_count_;
-  // Hyper param is placed on the top graph,
+  // Hyper param is used as free variable and placed on the top graph.
   // and positioned in the end of the param list, so we record the number to trace the position.
-  size_t hyper_param_count_;
+  size_t fv_param_count_;
   // Argument input list for the graph used to generate this graph.
   bool is_generated_;
   // CNode that calls 'return' primitive.
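
The renamed counter keeps the same layout contract as before: FV parameters sit at the tail of parameters_, and fv_param_count_ records how many there are. A hypothetical helper (not part of this PR, generic types instead of AnfNodePtr) that slices them out under that assumption:

    #include <cstddef>
    #include <vector>

    // Sketch: the last fv_param_count entries of params are the FV (weight) parameters;
    // everything before them are the user-visible inputs, *args/**kwargs included.
    template <typename NodePtr>
    std::vector<NodePtr> TailFvParameters(const std::vector<NodePtr> &params, std::size_t fv_param_count) {
      if (params.size() < fv_param_count) {  // mirrors the guards in GetVariableKwargParameter()
        return {};
      }
      return std::vector<NodePtr>(params.end() - static_cast<std::ptrdiff_t>(fv_param_count), params.end());
    }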

View File

@@ -256,7 +256,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr
   target_func_graph->set_has_vararg(func_graph->has_vararg());
   target_func_graph->set_has_kwarg(func_graph->has_kwarg());
   target_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
-  target_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
+  target_func_graph->set_fv_param_count(func_graph->fv_param_count());
   target_func_graph->set_is_generate(func_graph->is_generated());
   target_func_graph->set_stub(func_graph->stub());
   target_func_graph->set_switch_input(func_graph->switch_input());
@@ -822,7 +822,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
   new_func_graph->set_has_vararg(func_graph->has_vararg());
   new_func_graph->set_has_kwarg(func_graph->has_kwarg());
   new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
-  new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
+  new_func_graph->set_fv_param_count(func_graph->fv_param_count());
   new_func_graph->set_is_generate(func_graph->is_generated());
   new_func_graph->set_stub(func_graph->stub());
   new_func_graph->set_switch_input(func_graph->switch_input());

View File

@@ -196,7 +196,7 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
                                      const std::vector<AnfNodePtr> &specialized_parameter_list,
                                      mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
   MS_EXCEPTION_IF_NULL(specialized_graph);
-  for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
+  for (size_t i = 0; i < specialized_graph->parameters().size() - fv_param_count(); ++i) {
     MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
     auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
     MS_EXCEPTION_IF_NULL(param_node);
@@ -222,10 +222,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
   std::vector<size_t> pos_arg_indexes;
   size_t arguments_count = args_spec_list.size();
-  if (hyper_param_count_ > arguments_count) {
+  if (fv_param_count_ > arguments_count) {
     MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
   }
-  for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
+  for (size_t i = 0; i < arguments_count - fv_param_count_; i++) {
     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
     if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
       kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
@@ -243,7 +243,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   }
   FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
   size_t kwarg_count = kwarg_list.size();
-  int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - hyper_param_count_);
+  int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - fv_param_count_);
   int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
   int variable_args_count = pos_args_input_count - pos_args_count;
   std::vector<AnfNodePtr> specialized_parameter_list;
@@ -263,7 +263,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   // append hyper parameter to specialized_parameter_list
   MS_EXCEPTION_IF_NULL(specialized_graph);
   auto params = specialized_graph->parameters();
-  specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
+  specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(fv_param_count_),
                                     params.end());
   std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
                                                             specialized_parameter_list.end());
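
To make the counting in GenerateGraph above concrete, here is the same arithmetic on made-up numbers (illustration only; none of the constants come from this PR):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Suppose the call site supplies 5 abstracts in total: 2 plain positionals,
      // 1 keyword argument, and 2 trailing FV parameters appended by the frontend.
      const int arguments_count = 5;
      const int kwarg_count = 1;
      const int fv_param_count = 2;
      const int declared_positional_count = 3;  // what GetPositionalArgsCount() would report

      const int pos_args_input_count = (arguments_count - kwarg_count) - fv_param_count;     // 2
      const int pos_args_count = std::min(pos_args_input_count, declared_positional_count);  // 2
      const int variable_args_count = pos_args_input_count - pos_args_count;                 // 0 -> nothing spills into *args
      std::printf("%d %d %d\n", pos_args_input_count, pos_args_count, variable_args_count);
      return 0;
    }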

View File

@@ -1512,7 +1512,7 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m
 bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
                                                     const std::map<std::string, ValuePtr> &weights) {
-  size_t hyper_param_count = 0;
+  size_t fv_param_count = 0;
   auto parameters = topGraph->parameters();
   for (int i = parameters.size() - 1; i >= 0; --i) {
     size_t index = IntToSize(i);
@@ -1536,9 +1536,9 @@ bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph
       return false;
     }
     parameter->set_default_param(weights_iter->second);
-    hyper_param_count++;
+    fv_param_count++;
   }
-  topGraph->set_hyper_param_count(hyper_param_count);
+  topGraph->set_fv_param_count(fv_param_count);
   return true;
 }

View File

@@ -0,0 +1,290 @@
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass non_tensor inputs"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
import mindspore.ops as ops
from mindspore import context
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    yield
    context.set_context(mode=context.GRAPH_MODE)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.add = P.TensorAdd()
        self.sub = P.Sub()

    def construct(self, tensor_param_x, tuple_a, list_b, tensor_param_y, tensor_param_z, dict_c):
        out = self.add(tensor_param_x, tuple_a[0])
        out = self.sub(out, list_b[1][1]["y"])
        out = self.add(out, tensor_param_y)
        out = self.sub(out, tensor_param_z)
        out = self.add(out, dict_c["u"])
        return out


class GradNet(nn.Cell):
    def __init__(self, net, get_all):
        super(GradNet, self).__init__()
        self.forward_net = net
        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
        self.grad_all = C.GradOperation(get_all=get_all)

    def construct(self, tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c):
        return self.grad_all(self.forward_net)(tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c)


tensor_x = Tensor(np.ones((2, 2), np.float32))
tensor_y = Tensor(np.ones((2, 2), np.float32) * 2)
tensor_z = Tensor(np.ones((2, 2), np.float32) * 3)
tensor_w = Tensor(np.ones((2, 2), np.float32) * 4)
tensor_p = Tensor(np.ones((2, 2), np.float32) * 5)
tensor_u = Tensor(np.ones((2, 2), np.float32) * 6)
tuple_arg = (tensor_x, tensor_y, tensor_z, tensor_w)
list_arg = [[tensor_x, tensor_x], [[tensor_x, tensor_y], {"x": tensor_x, "y": tensor_y, "z": tensor_z, "p": tensor_p}]]
dict_arg = {"x": tensor_x, "y": tensor_y, "u": tensor_u}


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_non_tensor_inputs(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Normal input type without tensor.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    # grad first input
    grad_fist_input_tensor_net = GradNet(Net(), get_all=False)
    ret = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg)
    assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))

    # grad all inputs
    grad_all_input_tensor_net = GradNet(Net(), get_all=True)
    ret_all = grad_all_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg)
    assert len(ret_all) == 3
    assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1)


class GradNet1(nn.Cell):
    def __init__(self, net, get_all):
        super(GradNet1, self).__init__()
        self.forward_net = net
        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
        self.grad_all = C.GradOperation(get_all=get_all)

    def construct(self, tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c):
        return self.grad_all(self.forward_net)(tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c)


# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_first_input_net(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Normal input type.
    Expectation: No exception.
    """
    class FirstInputTensorNet(nn.Cell):
        def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
            return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]

    context.set_context(mode=mode)
    grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
    print('res:', res)
    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))


class TestCell(nn.Cell):
    def __init__(self, param):
        super().__init__()
        self.a = Tensor(np.array([[1, 2], [3, 4]]))
        self.param = param

    def construct(self, x):
        return self.a * self.param * x


class GradCellWithParameter(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
        self.param = self.net.param

    def construct(self, x):
        return self.grad(self.net, self.param)(x)


class GradCell(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad_all = ops.GradOperation(get_all=True)

    def construct(self, x):
        return self.grad_all(self.net)(x)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_input(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameter as input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    z = Tensor(np.array([[7, 8], [9, 0]]))
    a = GradCell(TestCell(x))(y)
    b = GradCell(TestCell(x))(z)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0].asnumpy(), b[0].asnumpy())


# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    z = Tensor(np.array([[7, 8], [9, 0]]))
    a = GradCellWithParameter(TestCell(x))(y)
    b = GradCellWithParameter(TestCell(x))(z)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())


# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_same_parameter_both_input_and_fv(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with the same Parameter used as input type and fv at the same time.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    y = Tensor(np.array([[1, 2], [3, 4]]))
    a = GradCellWithParameter(TestCell(x))(x)
    b = GradCellWithParameter(TestCell(x))(y)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())


class TestCell2(nn.Cell):
    def __init__(self, param1, param2):
        super().__init__()
        self.a = Tensor(np.array([[1, 2], [3, 4]]))
        self.param1 = param1
        self.param2 = param2

    def construct(self, x):
        return self.a * self.param1 * self.param2 * x


class GradCellWithParameterTuple(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
        self.param1 = self.net.param1
        self.param2 = self.net.param2
        self.params = ParameterTuple([self.param1, self.param2])

    def construct(self, x):
        return self.grad(self.net, self.params)(x)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_as_input_and_fv2(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x1 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x1')
    x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    z = Tensor(np.array([[7, 8], [9, 0]]))
    a = GradCellWithParameterTuple(TestCell2(x1, x2))(y)
    b = GradCellWithParameterTuple(TestCell2(x1, x2))(z)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
    assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
    assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
View File

@ -1,82 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass non_tensor inputs"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import context
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.add = P.TensorAdd()
        self.sub = P.Sub()

    def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c):
        out = self.add(tensor_x, tuple_a[0])
        out = self.sub(out, list_b[1][1]["y"])
        out = self.add(out, tensor_y)
        out = self.sub(out, tensor_z)
        out = self.add(out, dict_c["u"])
        return out


class GradNet(nn.Cell):
    def __init__(self, net, get_all):
        super(GradNet, self).__init__()
        self.forward_net = net
        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
        self.grad_all = C.GradOperation(get_all=get_all)

    def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c):
        return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c)


x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32) * 2)
z = Tensor(np.ones((2, 2), np.float32) * 3)
w = Tensor(np.ones((2, 2), np.float32) * 4)
p = Tensor(np.ones((2, 2), np.float32) * 5)
u = Tensor(np.ones((2, 2), np.float32) * 6)
arg_t0 = (x, y, z, w)
arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": z, "p": p}]]
args_d0 = {"x": x, "y": y, "u": u}


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_non_tensor_inputs():
    # grad first input
    grad_fist_input_tensor_net = GradNet(Net(), get_all=False)
    ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0)
    assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))

    # grad all inputs
    grad_all_input_tensor_net = GradNet(Net(), get_all=True)
    ret_all = grad_all_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0)
    assert len(ret_all) == 3
    assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1)

View File

@@ -36,7 +36,7 @@ constexpr auto kDependRealInputSize = 2;
 ParameterPtr TestCreateParameter(const KernelGraphPtr &g, const std::string &name,
                                  const abstract::AbstractBasePtr &abstract) {
   MS_EXCEPTION_IF_NULL(g);
-  auto parameter = g->AddWeightParameter(name);
+  auto parameter = g->AddFvParameter(name, abstract->BuildValue());
   if (parameter == nullptr) {
     MS_LOG(ERROR) << "Cannot add weight parameter!";
   }

View File

@@ -19,9 +19,9 @@ import pytest
 import mindspore.nn as nn
 from mindspore.common import mutable
 from mindspore import Tensor, Parameter, ParameterTuple
-from mindspore import context
 from mindspore.ops import composite as C
 import mindspore.ops as ops
+from mindspore import context
 @pytest.fixture(scope="module", autouse=True)
@@ -91,29 +91,7 @@ def test_grad_first_input_net(mode):
     context.set_context(mode=mode)
     grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
-    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
-    print('res:', res)
-    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
-
-
-# PyNative run error.
-# Support context.PYNATIVE_MODE later.
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
-def test_grad_first_input_net_pynative_error(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Normal input type.
-    Expectation: No exception.
-    """
-    class FirstInputTensorNet(nn.Cell):
-        def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
-            return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]
-
-    context.set_context(mode=mode)
-    grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
-    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
-    print('res:', res)
-    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
+    grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
 
 
 @pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
@@ -149,7 +127,6 @@ def test_outermost_net_pass_parameter(mode):
 # Support the Parameter as outermost input.
-# Support context.PYNATIVE_MODE UT later.
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE])
 def test_outermost_net_pass_tuple_including_parameter(mode):
     """
@@ -163,7 +140,6 @@ def test_outermost_net_pass_tuple_including_parameter(mode):
 # Support the Parameter as outermost input.
-# Support context.PYNATIVE_MODE UT later.
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE])
 def test_outermost_net_pass_list_including_parameter(mode):
     """
@@ -177,7 +153,6 @@ def test_outermost_net_pass_list_including_parameter(mode):
 # Support the Parameter as outermost input.
-# Support context.PYNATIVE_MODE UT later.
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE])
 def test_grad_net_pass_dict_including_parameter(mode):
     """
@@ -190,96 +165,6 @@ def test_grad_net_pass_dict_including_parameter(mode):
     forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, SCALAR_NUM, mutable_dict, flag_0)
 
 
-class TestCell(nn.Cell):
-    def __init__(self, param):
-        super().__init__()
-        self.a = Tensor(np.array([[1, 2], [3, 4]]))
-        self.param = param
-
-    def construct(self, x):
-        return self.a * self.param * x
-
-
-class GradCellWithParameter(nn.Cell):
-    def __init__(self, net):
-        super().__init__()
-        self.net = net
-        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
-        self.param = self.net.param
-
-    def construct(self, x):
-        return self.grad(self.net, self.param)(x)
-
-
-class GradCell(nn.Cell):
-    def __init__(self, net):
-        super().__init__()
-        self.net = net
-        self.grad_all = ops.GradOperation(get_all=True)
-
-    def construct(self, x):
-        return self.grad_all(self.net)(x)
-
-
-@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
-def test_grad_parameter_input(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Grad with Parameter as input type.
-    Expectation: No exception.
-    """
-    context.set_context(mode=mode)
-    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
-    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
-    z = Tensor(np.array([[7, 8], [9, 0]]))
-    a = GradCell(TestCell(x))(y)
-    b = GradCell(TestCell(x))(z)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0].asnumpy(), b[0].asnumpy())
-
-
-# PyNative run error.
-# Support context.PYNATIVE_MODE later.
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
-def test_grad_parameter_as_input_and_fv(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Grad with Parameters as input type and fv.
-    Expectation: No exception.
-    """
-    context.set_context(mode=mode)
-    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
-    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
-    z = Tensor(np.array([[7, 8], [9, 0]]))
-    a = GradCellWithParameter(TestCell(x))(y)
-    b = GradCellWithParameter(TestCell(x))(z)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
-    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
-
-
-# PyNative run error.
-# Support context.PYNATIVE_MODE later.
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
-def test_grad_same_parameter_both_input_and_fv(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Grad with the same Parameter used as input type and fv at the same time.
-    Expectation: No exception.
-    """
-    context.set_context(mode=mode)
-    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
-    y = Tensor(np.array([[1, 2], [3, 4]]))
-    a = GradCellWithParameter(TestCell(x))(x)
-    b = GradCellWithParameter(TestCell(x))(y)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
-    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
-
-
 class TestCell2(nn.Cell):
     def __init__(self, param1, param2):
         super().__init__()
@@ -329,7 +214,7 @@ class GradCellWithTupleOfParameter(nn.Cell):
 @pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
-def test_grad_parameter_as_input_and_fv2(mode):
+def test_grad_parameter_tuple(mode):
     """
     Feature: Construct()/ms_function input type with back propagate.
     Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
@@ -340,13 +225,8 @@ def test_grad_parameter_as_input_and_fv2(mode):
     x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
     y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
     z = Tensor(np.array([[7, 8], [9, 0]]))
-    a = GradCellWithParameterTuple(TestCell2(x1, x2))(y)
-    b = GradCellWithParameterTuple(TestCell2(x1, x2))(z)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
-    assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
-    assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
+    GradCellWithParameterTuple(TestCell2(x1, x2))(y)
+    GradCellWithParameterTuple(TestCell2(x1, x2))(z)
 
 
 @pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')