!35174 Change the weight parameter to FV parameter in FuncGraph.

Merge pull request !35174 from 张清华/opt_parameter
i-robot 2022-06-01 03:32:28 +00:00 committed by Gitee
commit 6c1ea8074e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
20 changed files with 372 additions and 307 deletions
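The rename is mechanical throughout the diff: the member hyper_param_count_ and its accessors hyper_param_count()/set_hyper_param_count() become fv_param_count_/fv_param_count()/set_fv_param_count(), and the old three-step weight-parameter setup (AddWeightParameter, then set_default_param and set_abstract by hand) collapses into a single AddFvParameter call that takes the default value. A minimal before/after sketch of the call-site change, using the names from the cache-embedding hunk below (graph, cache_name, new_tensor):

// Before this PR: create the weight parameter, then attach default value and abstract manually.
auto cache_param = graph->AddWeightParameter(cache_name);
cache_param->set_default_param(MakeValue(new_tensor));
cache_param->set_abstract(new_tensor->ToAbstract());

// After this PR: one call; AddFvParameter sets the default value and abstract itself
// and increments fv_param_count_ (see FuncGraph::AddFvParameter further down).
auto cache_param = graph->AddFvParameter(cache_name, new_tensor);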

View File

@ -338,13 +338,13 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
}),
root_parameters.end());
size_t remove_param_count = origin_param_count - root_parameters.size();
size_t hyper_param_count = root_graph->hyper_param_count();
if (remove_param_count > hyper_param_count) {
size_t fv_param_count = root_graph->fv_param_count();
if (remove_param_count > fv_param_count) {
MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
return;
}
hyper_param_count = hyper_param_count - remove_param_count;
root_graph->set_hyper_param_count(hyper_param_count);
fv_param_count = fv_param_count - remove_param_count;
root_graph->set_fv_param_count(fv_param_count);
manager->SetParameters(root_graph, root_parameters);
}
} // namespace

View File

@ -176,9 +176,9 @@ static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &
// 2. The arguments in caller may be less than the formal parameters in called as some parameters can have
// default value.
if (!called->has_vararg() &&
caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->hyper_param_count())) {
caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->fv_param_count())) {
size_t start_offset = called->GetPositionalArgsCount() + 1;
size_t end_offset = called->hyper_param_count();
size_t end_offset = called->fv_param_count();
new_args.erase(new_args.begin() + start_offset, new_args.end() - end_offset);
}
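A quick arithmetic sketch of the trimming above (the counts are made up for illustration, not taken from the source):

// Suppose called->GetPositionalArgsCount() == 2 and called->fv_param_count() == 3.
// The branch fires only when caller->inputs().size() > 1 + 2 + 3 == 6.
// Then start_offset == 3 and end_offset == 3, so the erase keeps the first 3 entries
// of new_args plus its last 3 entries (the FV arguments) and drops everything in between.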

View File

@ -52,9 +52,7 @@ ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter
auto cache_name = ori_param_name + "_cache";
new_param_info->set_name(cache_name);
new_tensor->set_param_info(new_param_info);
auto cache_param = graph->AddWeightParameter(cache_name);
cache_param->set_default_param(MakeValue(new_tensor));
cache_param->set_abstract(new_tensor->ToAbstract());
auto cache_param = graph->AddFvParameter(cache_name, new_tensor);
cache_host_params_map[cache_param] = param;
}
}
@ -260,10 +258,7 @@ AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size,
std::string hashmap_name = "cache_hashmap";
new_param_info->set_name(hashmap_name);
new_tensor->set_param_info(new_param_info);
auto hashmap = func_graph->AddWeightParameter(hashmap_name);
hashmap->set_default_param(MakeValue(new_tensor));
hashmap->set_abstract(new_tensor->ToAbstract());
return hashmap;
return func_graph->AddFvParameter(hashmap_name, new_tensor);
}
AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
@ -273,10 +268,7 @@ AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
std::string step_name = "cache_step";
new_param_info->set_name(step_name);
new_tensor->set_param_info(new_param_info);
auto step = func_graph->AddWeightParameter(step_name);
step->set_default_param(MakeValue(new_tensor));
step->set_abstract(new_tensor->ToAbstract());
return step;
return func_graph->AddFvParameter(step_name, new_tensor);
}
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
@ -540,11 +532,7 @@ AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &or
auto new_param_name = name + "_pipe";
new_param_info->set_name(new_param_name);
new_tensor->set_param_info(new_param_info);
auto new_param = graph->AddWeightParameter(new_param_name);
new_param->set_default_param(MakeValue(new_tensor));
auto abs_tensor = new_tensor->ToAbstract();
new_param->set_abstract(abs_tensor);
return new_param->cast<AnfNodePtr>();
return graph->AddFvParameter(new_param_name, new_tensor);
}
AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {

View File

@ -1085,7 +1085,7 @@ void PipelineTransformer::ModifyParameterList() {
}
}
auto del_num = parameters.size() - parameter_list.size();
root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
root_->set_fv_param_count(root_->fv_param_count() - del_num);
manager_->SetParameters(root_, parameter_list);
}
} // namespace parallel

View File

@ -175,14 +175,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
}
}
if (para_node == nullptr) {
auto node = top_func_graph->AddWeightParameter(param_name);
auto value = py::cast<tensor::MetaTensorPtr>(obj);
para_node = top_func_graph->AddFvParameter(param_name, value);
param_obj_ids.emplace_back(obj_id);
node->set_default_param(value);
// Set abstract for parameter
auto abs = value->ToAbstract();
node->set_abstract(abs);
para_node = node;
MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
<< ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
}
@ -224,8 +219,8 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
// Update top_graph
top_graph->add_parameter(param_ptr);
size_t hyper_param_count = top_graph->hyper_param_count();
top_graph->set_hyper_param_count(hyper_param_count + 1);
size_t fv_param_count = top_graph->fv_param_count();
top_graph->set_fv_param_count(fv_param_count + 1);
} else {
input_params.push_back(param_ptr);
}

View File

@ -477,8 +477,8 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
auto func_graph = output->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto params = func_graph->parameters();
if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
if ((args.size() + func_graph->fv_param_count()) != params.size()) {
MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->fv_param_count()
<< " not equal to graph input size " << params.size() << ", let graph to be executed.";
}
@ -487,9 +487,9 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
}
size_t index = it - params.cbegin();
if (index >= args.size() + func_graph->hyper_param_count()) {
if (index >= args.size() + func_graph->fv_param_count()) {
MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
<< " add Parameter count " << func_graph->hyper_param_count() << ".";
<< " add Parameter count " << func_graph->fv_param_count() << ".";
}
if (index < args.size()) {
*ret_val = args[index];

View File

@ -41,7 +41,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
has_kwarg_(false),
exist_multi_target_(false),
kw_only_args_count_(0),
hyper_param_count_(0),
fv_param_count_(0),
is_generated_(false),
return_(nullptr),
manager_(),
@ -91,54 +91,56 @@ const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
ParameterPtr FuncGraph::add_parameter() {
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
add_parameter(p);
return p;
ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
add_parameter(param);
return param;
}
ParameterPtr FuncGraph::add_parameter(NodeDebugInfoPtr &&debug_info) {
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
add_parameter(p);
return p;
ParameterPtr param = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
add_parameter(param);
return param;
}
void FuncGraph::add_parameter(const ParameterPtr &p) {
void FuncGraph::add_parameter(const ParameterPtr &param) {
if (manager_.lock()) {
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
} else {
parameters_.push_back(p);
parameters_.push_back(param);
}
}
ParameterPtr FuncGraph::InsertFrontParameter() {
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
InsertFrontParameter(p);
return p;
ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
InsertFrontParameter(param);
return param;
}
void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
void FuncGraph::InsertFrontParameter(const ParameterPtr &param) {
if (manager_.lock()) {
manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), param);
} else {
PrependParameter(p);
PrependParameter(param);
}
}
ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
ParameterPtr FuncGraph::AddFvParameter(const std::string &name, const ValuePtr &default_value) {
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_graph);
p->set_name(name);
p->debug_info()->set_name(name);
ParameterPtr param = std::make_shared<Parameter>(this_graph);
param->set_name(name);
param->debug_info()->set_name(name);
MS_EXCEPTION_IF_NULL(default_value);
param->set_default_param(default_value);
param->set_abstract(default_value->ToAbstract());
if (manager_.lock()) {
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
} else {
parameters_.push_back(p);
parameters_.push_back(param);
}
hyper_param_count_++;
return p;
++fv_param_count_;
return param;
}
bool FuncGraph::has_flag(const std::string &key) const {
@ -573,11 +575,11 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() {
min_param_num += 1;
}
min_param_num += kw_only_args_count_;
min_param_num += hyper_param_count_;
min_param_num += fv_param_count_;
if (parameters_.size() < min_param_num) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
<< " which less than the sum of following: hyper_param_count: " << hyper_param_count_
<< " which less than the sum of following: fv_param_count: " << fv_param_count_
<< ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
<< ", kw_only_args_count_: " << kw_only_args_count_;
}
@ -598,22 +600,22 @@ std::string FuncGraph::GetVariableArgName() {
AnfNodePtr FuncGraph::GetVariableKwargParameter() {
if (has_kwarg_) {
if (parameters_.size() < hyper_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
if (parameters_.size() < fv_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
<< ", parameters is less than 1 + fv_param_count";
}
return parameters_[(parameters_.size() - hyper_param_count_) - 1];
return parameters_[(parameters_.size() - fv_param_count_) - 1];
}
return nullptr;
}
std::string FuncGraph::GetVariableKwargName() {
if (has_kwarg_) {
if (parameters_.size() < hyper_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
<< hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
if (parameters_.size() < fv_param_count_ + 1) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
<< ", parameters is less than 1 + fv_param_count";
}
const auto &parameter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
const auto &parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
return parameter->name();
}
@ -637,17 +639,17 @@ AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
varargs_kwargs_num += 1;
}
min_param_num += kw_only_args_count_;
min_param_num += hyper_param_count_;
min_param_num += fv_param_count_;
if (parameters_.size() < min_param_num) {
MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
<< " which less than the sum of following: hyper_param_count: " << hyper_param_count_
<< " which less than the sum of following: fv_param_count: " << fv_param_count_
<< ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
<< ", kw_only_args_count: " << kw_only_args_count_;
}
size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
std::copy(parameters_.cbegin() + kw_only_args_start_offset,
parameters_.cend() - hyper_param_count_ - varargs_kwargs_num, std::back_inserter(kw_only_args));
std::copy(parameters_.cbegin() + kw_only_args_start_offset, parameters_.cend() - fv_param_count_ - varargs_kwargs_num,
std::back_inserter(kw_only_args));
return kw_only_args;
}
@ -659,7 +661,7 @@ int FuncGraph::GetPositionalArgsCount() const {
if (has_vararg_) {
count--;
}
return (count - kw_only_args_count_) - SizeToInt(hyper_param_count_);
return (count - kw_only_args_count_) - SizeToInt(fv_param_count_);
}
AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
@ -763,13 +765,6 @@ CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const std::ve
return NewCNodeInOrder(std::move(input_node_list));
}
ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
auto parameter = add_parameter();
parameter->set_default_param(MakeValue(meta_tensor));
parameter->set_abstract(meta_tensor->ToAbstract());
return parameter;
}
void FuncGraph::SetMultiTarget() {
auto graph_manager = manager();
MS_EXCEPTION_IF_NULL(graph_manager);

View File

@ -132,8 +132,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
void set_parameters(std::vector<AnfNodePtr> &&params) { parameters_ = std::move(params); }
// Add a weight parameter with specific name.
ParameterPtr AddWeightParameter(const std::string &name);
// Add a FV weight parameter with specific name.
ParameterPtr AddFvParameter(const std::string &name, const ValuePtr &default_value);
// Create a cnode with given inputs, bound to this graph.
virtual CNodePtr NewCNode(std::vector<AnfNodePtr> &&inputs);
@ -154,7 +154,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
// Create a cnode with given inputs, put it to order list after the position node.
CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);
virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
// Functions for handling variable argument, keyword-only arguments and variable keyword argument.
AnfNodePtr GetDefaultValueByName(const std::string &name);
void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
@ -176,8 +175,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
AnfNodePtr GetVariableKwargParameter();
std::string GetVariableKwargName();
AnfNodePtrList GetKwOnlyArgsParameters();
void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
size_t hyper_param_count() const { return hyper_param_count_; }
void set_fv_param_count(size_t count) { fv_param_count_ = count; }
size_t fv_param_count() const { return fv_param_count_; }
int GetPositionalArgsCount() const;
AnfNodePtr GetParameterByName(const std::string &name);
bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
@ -418,9 +417,9 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
bool has_kwarg_;
bool exist_multi_target_;
int kw_only_args_count_;
// Hyper param is placed on the top graph,
// Hyper param is used as free variable and placed on the top graph.
// and positioned in the end of the param list, so we record the number to trace the position.
size_t hyper_param_count_;
size_t fv_param_count_;
// Argument input list for the graph used to generate this graph.
bool is_generated_;
// CNode that calls 'return' primitive.
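The renamed member above captures the key invariant: FV (free-variable) parameters always occupy the tail of parameters_, and fv_param_count_ only records how many there are. A small sketch of the indexing this implies, based on the accessors in func_graph.cc shown earlier (the concrete sizes are illustrative only):

// The last fv_param_count_ entries of parameters_ are the FV parameters; the **kwargs
// parameter, when present, sits immediately in front of that block:
//   GetVariableKwargParameter() returns parameters_[(parameters_.size() - fv_param_count_) - 1]
// e.g. with parameters_.size() == 6 and fv_param_count_ == 3, that is parameters_[2], and
// GetPositionalArgsCount() subtracts the same fv_param_count_ (along with the vararg,
// kwarg and kw-only slots) from the total parameter count.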

View File

@ -256,7 +256,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr
target_func_graph->set_has_vararg(func_graph->has_vararg());
target_func_graph->set_has_kwarg(func_graph->has_kwarg());
target_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
target_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
target_func_graph->set_fv_param_count(func_graph->fv_param_count());
target_func_graph->set_is_generate(func_graph->is_generated());
target_func_graph->set_stub(func_graph->stub());
target_func_graph->set_switch_input(func_graph->switch_input());
@ -822,7 +822,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph->set_has_vararg(func_graph->has_vararg());
new_func_graph->set_has_kwarg(func_graph->has_kwarg());
new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
new_func_graph->set_fv_param_count(func_graph->fv_param_count());
new_func_graph->set_is_generate(func_graph->is_generated());
new_func_graph->set_stub(func_graph->stub());
new_func_graph->set_switch_input(func_graph->switch_input());

View File

@ -196,7 +196,7 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
const std::vector<AnfNodePtr> &specialized_parameter_list,
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
MS_EXCEPTION_IF_NULL(specialized_graph);
for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
for (size_t i = 0; i < specialized_graph->parameters().size() - fv_param_count(); ++i) {
MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
@ -222,10 +222,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
std::vector<size_t> pos_arg_indexes;
size_t arguments_count = args_spec_list.size();
if (hyper_param_count_ > arguments_count) {
if (fv_param_count_ > arguments_count) {
MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
}
for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
for (size_t i = 0; i < arguments_count - fv_param_count_; i++) {
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
@ -243,7 +243,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
}
FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
size_t kwarg_count = kwarg_list.size();
int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - hyper_param_count_);
int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - fv_param_count_);
int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
int variable_args_count = pos_args_input_count - pos_args_count;
std::vector<AnfNodePtr> specialized_parameter_list;
@ -263,7 +263,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
// append hyper parameter to specialized_parameter_list
MS_EXCEPTION_IF_NULL(specialized_graph);
auto params = specialized_graph->parameters();
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(fv_param_count_),
params.end());
std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
specialized_parameter_list.end());

View File

@ -1512,7 +1512,7 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m
bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
const std::map<std::string, ValuePtr> &weights) {
size_t hyper_param_count = 0;
size_t fv_param_count = 0;
auto parameters = topGraph->parameters();
for (int i = parameters.size() - 1; i >= 0; --i) {
size_t index = IntToSize(i);
@ -1536,9 +1536,9 @@ bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph
return false;
}
parameter->set_default_param(weights_iter->second);
hyper_param_count++;
fv_param_count++;
}
topGraph->set_hyper_param_count(hyper_param_count);
topGraph->set_fv_param_count(fv_param_count);
return true;
}

View File

@ -0,0 +1,290 @@
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass non_tensor inputs"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
import mindspore.ops as ops
from mindspore import context
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
yield
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
def construct(self, tensor_param_x, tuple_a, list_b, tensor_param_y, tensor_param_z, dict_c):
out = self.add(tensor_param_x, tuple_a[0])
out = self.sub(out, list_b[1][1]["y"])
out = self.add(out, tensor_param_y)
out = self.sub(out, tensor_param_z)
out = self.add(out, dict_c["u"])
return out
class GradNet(nn.Cell):
def __init__(self, net, get_all):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=get_all)
def construct(self, tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c):
return self.grad_all(self.forward_net)(tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c)
tensor_x = Tensor(np.ones((2, 2), np.float32))
tensor_y = Tensor(np.ones((2, 2), np.float32) * 2)
tensor_z = Tensor(np.ones((2, 2), np.float32) * 3)
tensor_w = Tensor(np.ones((2, 2), np.float32) * 4)
tensor_p = Tensor(np.ones((2, 2), np.float32) * 5)
tensor_u = Tensor(np.ones((2, 2), np.float32) * 6)
tuple_arg = (tensor_x, tensor_y, tensor_z, tensor_w)
list_arg = [[tensor_x, tensor_x], [[tensor_x, tensor_y], {"x": tensor_x, "y": tensor_y, "z": tensor_z, "p": tensor_p}]]
dict_arg = {"x": tensor_x, "y": tensor_y, "u": tensor_u}
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_non_tensor_inputs(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Normal input type without tensor.
Expectation: No exception.
"""
context.set_context(mode=mode)
# grad first input
grad_fist_input_tensor_net = GradNet(Net(), get_all=False)
ret = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg)
assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
# grad all inputs
grad_all_input_tensor_net = GradNet(Net(), get_all=True)
ret_all = grad_all_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg)
assert len(ret_all) == 3
assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32))
assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32))
assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1)
class GradNet1(nn.Cell):
def __init__(self, net, get_all):
super(GradNet1, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=get_all)
def construct(self, tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c):
return self.grad_all(self.forward_net)(tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c)
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_first_input_net(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Normal input type.
Expectation: No exception.
"""
class FirstInputTensorNet(nn.Cell):
def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]
context.set_context(mode=mode)
grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
print('res:', res)
assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
class TestCell(nn.Cell):
def __init__(self, param):
super().__init__()
self.a = Tensor(np.array([[1, 2], [3, 4]]))
self.param = param
def construct(self, x):
return self.a * self.param * x
class GradCellWithParameter(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
self.grad = ops.GradOperation(get_all=True, get_by_list=True)
self.param = self.net.param
def construct(self, x):
return self.grad(self.net, self.param)(x)
class GradCell(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
self.grad_all = ops.GradOperation(get_all=True)
def construct(self, x):
return self.grad_all(self.net)(x)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_input(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with Parameter as input type.
Expectation: No exception.
"""
context.set_context(mode=mode)
x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
z = Tensor(np.array([[7, 8], [9, 0]]))
a = GradCell(TestCell(x))(y)
b = GradCell(TestCell(x))(z)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0].asnumpy(), b[0].asnumpy())
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with Parameters as input type and fv.
Expectation: No exception.
"""
context.set_context(mode=mode)
x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
z = Tensor(np.array([[7, 8], [9, 0]]))
a = GradCellWithParameter(TestCell(x))(y)
b = GradCellWithParameter(TestCell(x))(z)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_same_parameter_both_input_and_fv(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with the same Parameter used as input type and fv at the same time.
Expectation: No exception.
"""
context.set_context(mode=mode)
x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
y = Tensor(np.array([[1, 2], [3, 4]]))
a = GradCellWithParameter(TestCell(x))(x)
b = GradCellWithParameter(TestCell(x))(y)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
class TestCell2(nn.Cell):
def __init__(self, param1, param2):
super().__init__()
self.a = Tensor(np.array([[1, 2], [3, 4]]))
self.param1 = param1
self.param2 = param2
def construct(self, x):
return self.a * self.param1 * self.param2 * x
class GradCellWithParameterTuple(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
self.grad = ops.GradOperation(get_all=True, get_by_list=True)
self.param1 = self.net.param1
self.param2 = self.net.param2
self.params = ParameterTuple([self.param1, self.param2])
def construct(self, x):
return self.grad(self.net, self.params)(x)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_as_input_and_fv2(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
Expectation: No exception.
"""
context.set_context(mode=mode)
x1 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x1')
x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
z = Tensor(np.array([[7, 8], [9, 0]]))
a = GradCellWithParameterTuple(TestCell2(x1, x2))(y)
b = GradCellWithParameterTuple(TestCell2(x1, x2))(z)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())

View File

@ -1,82 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass non_tensor inputs"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import context
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c):
out = self.add(tensor_x, tuple_a[0])
out = self.sub(out, list_b[1][1]["y"])
out = self.add(out, tensor_y)
out = self.sub(out, tensor_z)
out = self.add(out, dict_c["u"])
return out
class GradNet(nn.Cell):
def __init__(self, net, get_all):
super(GradNet, self).__init__()
self.forward_net = net
self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
self.grad_all = C.GradOperation(get_all=get_all)
def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c):
return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c)
x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32) * 2)
z = Tensor(np.ones((2, 2), np.float32) * 3)
w = Tensor(np.ones((2, 2), np.float32) * 4)
p = Tensor(np.ones((2, 2), np.float32) * 5)
u = Tensor(np.ones((2, 2), np.float32) * 6)
arg_t0 = (x, y, z, w)
arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": z, "p": p}]]
args_d0 = {"x": x, "y": y, "u": u}
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_non_tensor_inputs():
# grad first input
grad_fist_input_tensor_net = GradNet(Net(), get_all=False)
ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0)
assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
# grad all inputs
grad_all_input_tensor_net = GradNet(Net(), get_all=True)
ret_all = grad_all_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0)
assert len(ret_all) == 3
assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32))
assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32))
assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1)

View File

@ -36,7 +36,7 @@ constexpr auto kDependRealInputSize = 2;
ParameterPtr TestCreateParameter(const KernelGraphPtr &g, const std::string &name,
const abstract::AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(g);
auto parameter = g->AddWeightParameter(name);
auto parameter = g->AddFvParameter(name, abstract->BuildValue());
if (parameter == nullptr) {
MS_LOG(ERROR) << "Cannot add weight parameter!";
}

View File

@ -19,9 +19,9 @@ import pytest
import mindspore.nn as nn
from mindspore.common import mutable
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import context
from mindspore.ops import composite as C
import mindspore.ops as ops
from mindspore import context
@pytest.fixture(scope="module", autouse=True)
@ -91,29 +91,7 @@ def test_grad_first_input_net(mode):
context.set_context(mode=mode)
grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
print('res:', res)
assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_first_input_net_pynative_error(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Normal input type.
Expectation: No exception.
"""
class FirstInputTensorNet(nn.Cell):
def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]
context.set_context(mode=mode)
grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
print('res:', res)
assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
@ -149,7 +127,6 @@ def test_outermost_net_pass_parameter(mode):
# Support the Parameter as outermost input.
# Support context.PYNATIVE_MODE UT later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_outermost_net_pass_tuple_including_parameter(mode):
"""
@ -163,7 +140,6 @@ def test_outermost_net_pass_tuple_including_parameter(mode):
# Support the Parameter as outermost input.
# Support context.PYNATIVE_MODE UT later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_outermost_net_pass_list_including_parameter(mode):
"""
@ -177,7 +153,6 @@ def test_outermost_net_pass_list_including_parameter(mode):
# Support the Parameter as outermost input.
# Support context.PYNATIVE_MODE UT later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_net_pass_dict_including_parameter(mode):
"""
@ -190,96 +165,6 @@ def test_grad_net_pass_dict_including_parameter(mode):
forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, SCALAR_NUM, mutable_dict, flag_0)
class TestCell(nn.Cell):
def __init__(self, param):
super().__init__()
self.a = Tensor(np.array([[1, 2], [3, 4]]))
self.param = param
def construct(self, x):
return self.a * self.param * x
class GradCellWithParameter(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
self.grad = ops.GradOperation(get_all=True, get_by_list=True)
self.param = self.net.param
def construct(self, x):
return self.grad(self.net, self.param)(x)
class GradCell(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
self.grad_all = ops.GradOperation(get_all=True)
def construct(self, x):
return self.grad_all(self.net)(x)
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_input(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with Parameter as input type.
Expectation: No exception.
"""
context.set_context(mode=mode)
x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
z = Tensor(np.array([[7, 8], [9, 0]]))
a = GradCell(TestCell(x))(y)
b = GradCell(TestCell(x))(z)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0].asnumpy(), b[0].asnumpy())
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with Parameters as input type and fv.
Expectation: No exception.
"""
context.set_context(mode=mode)
x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
z = Tensor(np.array([[7, 8], [9, 0]]))
a = GradCellWithParameter(TestCell(x))(y)
b = GradCellWithParameter(TestCell(x))(z)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_same_parameter_both_input_and_fv(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with the same Parameter used as input type and fv at the same time.
Expectation: No exception.
"""
context.set_context(mode=mode)
x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
y = Tensor(np.array([[1, 2], [3, 4]]))
a = GradCellWithParameter(TestCell(x))(x)
b = GradCellWithParameter(TestCell(x))(y)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
class TestCell2(nn.Cell):
def __init__(self, param1, param2):
super().__init__()
@ -329,7 +214,7 @@ class GradCellWithTupleOfParameter(nn.Cell):
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv2(mode):
def test_grad_parameter_tuple(mode):
"""
Feature: Construct()/ms_function input type with back propagate.
Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
@ -340,13 +225,8 @@ def test_grad_parameter_as_input_and_fv2(mode):
x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
z = Tensor(np.array([[7, 8], [9, 0]]))
a = GradCellWithParameterTuple(TestCell2(x1, x2))(y)
b = GradCellWithParameterTuple(TestCell2(x1, x2))(z)
print(f'a: {a}')
print(f'b: {b}')
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
GradCellWithParameterTuple(TestCell2(x1, x2))(y)
GradCellWithParameterTuple(TestCell2(x1, x2))(z)
@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')