forked from mindspore-Ecosystem/mindspore
!35174 Change the weight parameter to FV parameter in FuncGraph.
Merge pull request !35174 from 张清华/opt_parameter
Commit: 6c1ea8074e
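Summary of the change, with an illustrative before/after sketch (the names graph and new_tensor are taken from the hunks below; this is not a complete program): FuncGraph's hyper_param count is renamed to fv_param count, and AddWeightParameter(name) is replaced by AddFvParameter(name, default_value), which sets the parameter's default value and abstract itself.

// Before this commit: three steps at every call site.
auto param = graph->AddWeightParameter(name);
param->set_default_param(MakeValue(new_tensor));
param->set_abstract(new_tensor->ToAbstract());

// After this commit: one call; the FV parameter count is updated internally.
auto fv_param = graph->AddFvParameter(name, new_tensor);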
@@ -338,13 +338,13 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
 }),
 root_parameters.end());
 size_t remove_param_count = origin_param_count - root_parameters.size();
- size_t hyper_param_count = root_graph->hyper_param_count();
- if (remove_param_count > hyper_param_count) {
+ size_t fv_param_count = root_graph->fv_param_count();
+ if (remove_param_count > fv_param_count) {
 MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
 return;
 }
- hyper_param_count = hyper_param_count - remove_param_count;
- root_graph->set_hyper_param_count(hyper_param_count);
+ fv_param_count = fv_param_count - remove_param_count;
+ root_graph->set_fv_param_count(fv_param_count);
 manager->SetParameters(root_graph, root_parameters);
 }
 }  // namespace
@@ -176,9 +176,9 @@ static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &
 // 2. The arguments in caller may be less than the formal parameters in called as some parameters can have
 // default value.
 if (!called->has_vararg() &&
- caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->hyper_param_count())) {
+ caller->inputs().size() > (1 + called->GetPositionalArgsCount() + called->fv_param_count())) {
 size_t start_offset = called->GetPositionalArgsCount() + 1;
- size_t end_offset = called->hyper_param_count();
+ size_t end_offset = called->fv_param_count();
 new_args.erase(new_args.begin() + start_offset, new_args.end() - end_offset);
 }
@@ -52,9 +52,7 @@ ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter
 auto cache_name = ori_param_name + "_cache";
 new_param_info->set_name(cache_name);
 new_tensor->set_param_info(new_param_info);
- auto cache_param = graph->AddWeightParameter(cache_name);
- cache_param->set_default_param(MakeValue(new_tensor));
- cache_param->set_abstract(new_tensor->ToAbstract());
+ auto cache_param = graph->AddFvParameter(cache_name, new_tensor);
 cache_host_params_map[cache_param] = param;
 }
 }
@@ -260,10 +258,7 @@ AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size,
 std::string hashmap_name = "cache_hashmap";
 new_param_info->set_name(hashmap_name);
 new_tensor->set_param_info(new_param_info);
- auto hashmap = func_graph->AddWeightParameter(hashmap_name);
- hashmap->set_default_param(MakeValue(new_tensor));
- hashmap->set_abstract(new_tensor->ToAbstract());
- return hashmap;
+ return func_graph->AddFvParameter(hashmap_name, new_tensor);
 }

 AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
@@ -273,10 +268,7 @@ AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
 std::string step_name = "cache_step";
 new_param_info->set_name(step_name);
 new_tensor->set_param_info(new_param_info);
- auto step = func_graph->AddWeightParameter(step_name);
- step->set_default_param(MakeValue(new_tensor));
- step->set_abstract(new_tensor->ToAbstract());
- return step;
+ return func_graph->AddFvParameter(step_name, new_tensor);
 }

 AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
@@ -540,11 +532,7 @@ AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &or
 auto new_param_name = name + "_pipe";
 new_param_info->set_name(new_param_name);
 new_tensor->set_param_info(new_param_info);
- auto new_param = graph->AddWeightParameter(new_param_name);
- new_param->set_default_param(MakeValue(new_tensor));
- auto abs_tensor = new_tensor->ToAbstract();
- new_param->set_abstract(abs_tensor);
- return new_param->cast<AnfNodePtr>();
+ return graph->AddFvParameter(new_param_name, new_tensor);
 }

 AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
@@ -1085,7 +1085,7 @@ void PipelineTransformer::ModifyParameterList() {
 }
 }
 auto del_num = parameters.size() - parameter_list.size();
- root_->set_hyper_param_count(root_->hyper_param_count() - del_num);
+ root_->set_fv_param_count(root_->fv_param_count() - del_num);
 manager_->SetParameters(root_, parameter_list);
 }
 }  // namespace parallel
@@ -175,14 +175,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
 }
 }
 if (para_node == nullptr) {
- auto node = top_func_graph->AddWeightParameter(param_name);
 auto value = py::cast<tensor::MetaTensorPtr>(obj);
+ para_node = top_func_graph->AddFvParameter(param_name, value);
 param_obj_ids.emplace_back(obj_id);
- node->set_default_param(value);
- // Set abstract for parameter
- auto abs = value->ToAbstract();
- node->set_abstract(abs);
- para_node = node;
 MS_LOG(DEBUG) << "Created a new weight parameter for " << func_graph->ToString()
 << ", param: " << para_node->DebugString() << ", top_func_graph: " << top_func_graph->ToString();
 }
@@ -224,8 +219,8 @@ void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {

 // Update top_graph
 top_graph->add_parameter(param_ptr);
- size_t hyper_param_count = top_graph->hyper_param_count();
- top_graph->set_hyper_param_count(hyper_param_count + 1);
+ size_t fv_param_count = top_graph->fv_param_count();
+ top_graph->set_fv_param_count(fv_param_count + 1);
 } else {
 input_params.push_back(param_ptr);
 }
@@ -477,8 +477,8 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
 auto func_graph = output->func_graph();
 MS_EXCEPTION_IF_NULL(func_graph);
 auto params = func_graph->parameters();
- if ((args.size() + func_graph->hyper_param_count()) != params.size()) {
- MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->hyper_param_count()
+ if ((args.size() + func_graph->fv_param_count()) != params.size()) {
+ MS_LOG(EXCEPTION) << "Input size " << args.size() << " add Parameter count " << func_graph->fv_param_count()
 << " not equal to graph input size " << params.size() << ", let graph to be executed.";
 }
@@ -487,9 +487,9 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
 MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
 }
 size_t index = it - params.cbegin();
- if (index >= args.size() + func_graph->hyper_param_count()) {
+ if (index >= args.size() + func_graph->fv_param_count()) {
 MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size()
- << " add Parameter count " << func_graph->hyper_param_count() << ".";
+ << " add Parameter count " << func_graph->fv_param_count() << ".";
 }
 if (index < args.size()) {
 *ret_val = args[index];
@@ -41,7 +41,7 @@ FuncGraph::FuncGraph(GraphDebugInfoPtr &&debug_info)
 has_kwarg_(false),
 exist_multi_target_(false),
 kw_only_args_count_(0),
- hyper_param_count_(0),
+ fv_param_count_(0),
 is_generated_(false),
 return_(nullptr),
 manager_(),
@@ -91,54 +91,56 @@ const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {

 ParameterPtr FuncGraph::add_parameter() {
 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
- ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
- add_parameter(p);
- return p;
+ ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
+ add_parameter(param);
+ return param;
 }

 ParameterPtr FuncGraph::add_parameter(NodeDebugInfoPtr &&debug_info) {
 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
- ParameterPtr p = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
- add_parameter(p);
- return p;
+ ParameterPtr param = std::make_shared<Parameter>(this_func_graph, std::move(debug_info));
+ add_parameter(param);
+ return param;
 }

- void FuncGraph::add_parameter(const ParameterPtr &p) {
+ void FuncGraph::add_parameter(const ParameterPtr &param) {
 if (manager_.lock()) {
- manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
+ manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
 } else {
- parameters_.push_back(p);
+ parameters_.push_back(param);
 }
 }

 ParameterPtr FuncGraph::InsertFrontParameter() {
 FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
- ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
- InsertFrontParameter(p);
- return p;
+ ParameterPtr param = std::make_shared<Parameter>(this_func_graph);
+ InsertFrontParameter(param);
+ return param;
 }

- void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
+ void FuncGraph::InsertFrontParameter(const ParameterPtr &param) {
 if (manager_.lock()) {
- manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
+ manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), param);
 } else {
- PrependParameter(p);
+ PrependParameter(param);
 }
 }

- ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
+ ParameterPtr FuncGraph::AddFvParameter(const std::string &name, const ValuePtr &default_value) {
 FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
- ParameterPtr p = std::make_shared<Parameter>(this_graph);
- p->set_name(name);
- p->debug_info()->set_name(name);
+ ParameterPtr param = std::make_shared<Parameter>(this_graph);
+ param->set_name(name);
+ param->debug_info()->set_name(name);
+ MS_EXCEPTION_IF_NULL(default_value);
+ param->set_default_param(default_value);
+ param->set_abstract(default_value->ToAbstract());
 if (manager_.lock()) {
- manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
+ manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), param);
 } else {
- parameters_.push_back(p);
+ parameters_.push_back(param);
 }
- hyper_param_count_++;
- return p;
+ ++fv_param_count_;
+ return param;
 }

 bool FuncGraph::has_flag(const std::string &key) const {
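The new FuncGraph::AddFvParameter shown above takes over what call sites used to do by hand: it names the parameter, rejects a null default value, stores the default, derives the abstract from it, registers the parameter through the manager when one is attached (otherwise appends it to parameters_), and increments fv_param_count_. A minimal usage sketch, assuming the caller already holds a MetaTensorPtr named weight; the parameter name is made up for illustration:

// Hedged usage sketch; `graph`, `weight`, and the name "my_weight" are assumptions.
auto fv_param = graph->AddFvParameter("my_weight", MakeValue(weight));
// fv_param already carries its default value and abstract,
// and graph->fv_param_count() has grown by one.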
@@ -573,11 +575,11 @@ AnfNodePtr FuncGraph::GetVariableArgParameter() {
 min_param_num += 1;
 }
 min_param_num += kw_only_args_count_;
- min_param_num += hyper_param_count_;
+ min_param_num += fv_param_count_;

 if (parameters_.size() < min_param_num) {
 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
- << " which less than the sum of following: hyper_param_count: " << hyper_param_count_
+ << " which less than the sum of following: fv_param_count: " << fv_param_count_
 << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
 << ", kw_only_args_count_: " << kw_only_args_count_;
 }
@@ -598,22 +600,22 @@ std::string FuncGraph::GetVariableArgName() {

 AnfNodePtr FuncGraph::GetVariableKwargParameter() {
 if (has_kwarg_) {
- if (parameters_.size() < hyper_param_count_ + 1) {
- MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
- << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
+ if (parameters_.size() < fv_param_count_ + 1) {
+ MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
+ << ", parameters is less than 1 + fv_param_count";
 }
- return parameters_[(parameters_.size() - hyper_param_count_) - 1];
+ return parameters_[(parameters_.size() - fv_param_count_) - 1];
 }
 return nullptr;
 }

 std::string FuncGraph::GetVariableKwargName() {
 if (has_kwarg_) {
- if (parameters_.size() < hyper_param_count_ + 1) {
- MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is "
- << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count";
+ if (parameters_.size() < fv_param_count_ + 1) {
+ MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", fv_param_count is " << fv_param_count_
+ << ", parameters is less than 1 + fv_param_count";
 }
- const auto &parameter = parameters_[(parameters_.size() - hyper_param_count_) - 1]->cast<ParameterPtr>();
+ const auto &parameter = parameters_[(parameters_.size() - fv_param_count_) - 1]->cast<ParameterPtr>();
 MS_EXCEPTION_IF_NULL(parameter);
 return parameter->name();
 }
@@ -637,17 +639,17 @@ AnfNodePtrList FuncGraph::GetKwOnlyArgsParameters() {
 varargs_kwargs_num += 1;
 }
 min_param_num += kw_only_args_count_;
- min_param_num += hyper_param_count_;
+ min_param_num += fv_param_count_;

 if (parameters_.size() < min_param_num) {
 MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size()
- << " which less than the sum of following: hyper_param_count: " << hyper_param_count_
+ << " which less than the sum of following: fv_param_count: " << fv_param_count_
 << ", has_vararg: " << has_vararg_ << ", has_kwarg: " << has_kwarg_
 << ", kw_only_args_count: " << kw_only_args_count_;
 }
 size_t kw_only_args_start_offset = parameters_.size() - min_param_num;
- std::copy(parameters_.cbegin() + kw_only_args_start_offset,
- parameters_.cend() - hyper_param_count_ - varargs_kwargs_num, std::back_inserter(kw_only_args));
+ std::copy(parameters_.cbegin() + kw_only_args_start_offset, parameters_.cend() - fv_param_count_ - varargs_kwargs_num,
+ std::back_inserter(kw_only_args));
 return kw_only_args;
 }
@@ -659,7 +661,7 @@ int FuncGraph::GetPositionalArgsCount() const {
 if (has_vararg_) {
 count--;
 }
- return (count - kw_only_args_count_) - SizeToInt(hyper_param_count_);
+ return (count - kw_only_args_count_) - SizeToInt(fv_param_count_);
 }

 AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) {
|
||||||
return NewCNodeInOrder(std::move(input_node_list));
|
return NewCNodeInOrder(std::move(input_node_list));
|
||||||
}
|
}
|
||||||
|
|
||||||
ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
|
|
||||||
auto parameter = add_parameter();
|
|
||||||
parameter->set_default_param(MakeValue(meta_tensor));
|
|
||||||
parameter->set_abstract(meta_tensor->ToAbstract());
|
|
||||||
return parameter;
|
|
||||||
}
|
|
||||||
|
|
||||||
void FuncGraph::SetMultiTarget() {
|
void FuncGraph::SetMultiTarget() {
|
||||||
auto graph_manager = manager();
|
auto graph_manager = manager();
|
||||||
MS_EXCEPTION_IF_NULL(graph_manager);
|
MS_EXCEPTION_IF_NULL(graph_manager);
|
||||||
|
|
|
@@ -132,8 +132,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
 void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
 void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
 void set_parameters(std::vector<AnfNodePtr> &&params) { parameters_ = std::move(params); }
- // Add a weight parameter with specific name.
- ParameterPtr AddWeightParameter(const std::string &name);
+ // Add a FV weight parameter with specific name.
+ ParameterPtr AddFvParameter(const std::string &name, const ValuePtr &default_value);

 // Create a cnode with given inputs, bound to this graph.
 virtual CNodePtr NewCNode(std::vector<AnfNodePtr> &&inputs);
@@ -154,7 +154,6 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
 // Create a cnode with given inputs, put it to order list after the position node.
 CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);

- virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
 // Functions for handling variable argument, keyword-only arguments and variable keyword argument.
 AnfNodePtr GetDefaultValueByName(const std::string &name);
 void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
@@ -176,8 +175,8 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
 AnfNodePtr GetVariableKwargParameter();
 std::string GetVariableKwargName();
 AnfNodePtrList GetKwOnlyArgsParameters();
- void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
- size_t hyper_param_count() const { return hyper_param_count_; }
+ void set_fv_param_count(size_t count) { fv_param_count_ = count; }
+ size_t fv_param_count() const { return fv_param_count_; }
 int GetPositionalArgsCount() const;
 AnfNodePtr GetParameterByName(const std::string &name);
 bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
@@ -418,9 +417,9 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
 bool has_kwarg_;
 bool exist_multi_target_;
 int kw_only_args_count_;
- // Hyper param is placed on the top graph,
+ // Hyper param is used as free variable and placed on the top graph.
 // and positioned in the end of the param list, so we record the number to trace the position.
- size_t hyper_param_count_;
+ size_t fv_param_count_;
 // Argument input list for the graph used to generate this graph.
 bool is_generated_;
 // CNode that calls 'return' primitive.
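The renamed counter keeps the invariant the comment above describes: FV parameters (weights hoisted to the top graph as free variables) sit at the tail of parameters_, and fv_param_count_ records how many there are, so the accessors in the other hunks recover slices by subtracting counts from the end. A sketch of that arithmetic, illustrative only and not the exact library code:

// Assumed tail layout of parameters_ (per the accessors above):
// kw-only args, then the optional *args/**kwargs slots, then the FV parameters.
size_t varargs_kwargs = (has_vararg_ ? 1 : 0) + (has_kwarg_ ? 1 : 0);
size_t positional = parameters_.size() - kw_only_args_count_ - varargs_kwargs - fv_param_count_;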
@@ -256,7 +256,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr
 target_func_graph->set_has_vararg(func_graph->has_vararg());
 target_func_graph->set_has_kwarg(func_graph->has_kwarg());
 target_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
- target_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
+ target_func_graph->set_fv_param_count(func_graph->fv_param_count());
 target_func_graph->set_is_generate(func_graph->is_generated());
 target_func_graph->set_stub(func_graph->stub());
 target_func_graph->set_switch_input(func_graph->switch_input());
@@ -822,7 +822,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
 new_func_graph->set_has_vararg(func_graph->has_vararg());
 new_func_graph->set_has_kwarg(func_graph->has_kwarg());
 new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
- new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
+ new_func_graph->set_fv_param_count(func_graph->fv_param_count());
 new_func_graph->set_is_generate(func_graph->is_generated());
 new_func_graph->set_stub(func_graph->stub());
 new_func_graph->set_switch_input(func_graph->switch_input());
@@ -196,7 +196,7 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
 const std::vector<AnfNodePtr> &specialized_parameter_list,
 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
 MS_EXCEPTION_IF_NULL(specialized_graph);
- for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) {
+ for (size_t i = 0; i < specialized_graph->parameters().size() - fv_param_count(); ++i) {
 MS_EXCEPTION_IF_NULL(specialized_graph->parameters()[i]);
 auto param_node = specialized_graph->parameters()[i]->cast<ParameterPtr>();
 MS_EXCEPTION_IF_NULL(param_node);
@@ -222,10 +222,10 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
 std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
 std::vector<size_t> pos_arg_indexes;
 size_t arguments_count = args_spec_list.size();
- if (hyper_param_count_ > arguments_count) {
+ if (fv_param_count_ > arguments_count) {
 MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
 }
- for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
+ for (size_t i = 0; i < arguments_count - fv_param_count_; i++) {
 MS_EXCEPTION_IF_NULL(args_spec_list[i]);
 if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
 kwarg_list.push_back(args_spec_list[i]->cast<abstract::AbstractKeywordArgPtr>());
@@ -243,7 +243,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
 }
 FuncGraphPtr specialized_graph = BasicClone(shared_from_base<FuncGraph>());
 size_t kwarg_count = kwarg_list.size();
- int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - hyper_param_count_);
+ int pos_args_input_count = SizeToInt((arguments_count - kwarg_count) - fv_param_count_);
 int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount());
 int variable_args_count = pos_args_input_count - pos_args_count;
 std::vector<AnfNodePtr> specialized_parameter_list;
|
@ -263,7 +263,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
|
||||||
// append hyper parameter to specialized_parameter_list
|
// append hyper parameter to specialized_parameter_list
|
||||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||||
auto params = specialized_graph->parameters();
|
auto params = specialized_graph->parameters();
|
||||||
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(hyper_param_count_),
|
specialized_parameter_list.insert(specialized_parameter_list.end(), params.end() - SizeToInt(fv_param_count_),
|
||||||
params.end());
|
params.end());
|
||||||
std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
|
std::vector<AnfNodePtr> specialized_parameter_list_update(specialized_parameter_list.begin() + pos_arg_indexes.size(),
|
||||||
specialized_parameter_list.end());
|
specialized_parameter_list.end());
|
||||||
|
|
|
@@ -1512,7 +1512,7 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const mind_ir::ModelProto &m

 bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph,
 const std::map<std::string, ValuePtr> &weights) {
- size_t hyper_param_count = 0;
+ size_t fv_param_count = 0;
 auto parameters = topGraph->parameters();
 for (int i = parameters.size() - 1; i >= 0; --i) {
 size_t index = IntToSize(i);
@@ -1536,9 +1536,9 @@ bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph
 return false;
 }
 parameter->set_default_param(weights_iter->second);
- hyper_param_count++;
+ fv_param_count++;
 }
- topGraph->set_hyper_param_count(hyper_param_count);
+ topGraph->set_fv_param_count(fv_param_count);
 return true;
 }
@@ -0,0 +1,290 @@ (new test file)
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass non_tensor inputs"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
import mindspore.ops as ops
from mindspore import context


@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    yield
    context.set_context(mode=context.GRAPH_MODE)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.add = P.TensorAdd()
        self.sub = P.Sub()

    def construct(self, tensor_param_x, tuple_a, list_b, tensor_param_y, tensor_param_z, dict_c):
        out = self.add(tensor_param_x, tuple_a[0])
        out = self.sub(out, list_b[1][1]["y"])
        out = self.add(out, tensor_param_y)
        out = self.sub(out, tensor_param_z)
        out = self.add(out, dict_c["u"])
        return out


class GradNet(nn.Cell):
    def __init__(self, net, get_all):
        super(GradNet, self).__init__()
        self.forward_net = net
        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
        self.grad_all = C.GradOperation(get_all=get_all)

    def construct(self, tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c):
        return self.grad_all(self.forward_net)(tuple_a, tensor_param_x, list_b, tensor_param_y, tensor_param_z, dict_c)


tensor_x = Tensor(np.ones((2, 2), np.float32))
tensor_y = Tensor(np.ones((2, 2), np.float32) * 2)
tensor_z = Tensor(np.ones((2, 2), np.float32) * 3)
tensor_w = Tensor(np.ones((2, 2), np.float32) * 4)
tensor_p = Tensor(np.ones((2, 2), np.float32) * 5)
tensor_u = Tensor(np.ones((2, 2), np.float32) * 6)
tuple_arg = (tensor_x, tensor_y, tensor_z, tensor_w)
list_arg = [[tensor_x, tensor_x], [[tensor_x, tensor_y], {"x": tensor_x, "y": tensor_y, "z": tensor_z, "p": tensor_p}]]
dict_arg = {"x": tensor_x, "y": tensor_y, "u": tensor_u}


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_non_tensor_inputs(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Normal input type without tensor.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    # grad first input
    grad_fist_input_tensor_net = GradNet(Net(), get_all=False)
    ret = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg)
    assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
    # grad all inputs
    grad_all_input_tensor_net = GradNet(Net(), get_all=True)
    ret_all = grad_all_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_p, dict_arg)
    assert len(ret_all) == 3
    assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1)


class GradNet1(nn.Cell):
    def __init__(self, net, get_all):
        super(GradNet1, self).__init__()
        self.forward_net = net
        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
        self.grad_all = C.GradOperation(get_all=get_all)

    def construct(self, tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c):
        return self.grad_all(self.forward_net)(tuple_a, tensor_a, list_b, tensor_b, tensor_c, dict_c)


# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_first_input_net(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Normal input type.
    Expectation: No exception.
    """
    class FirstInputTensorNet(nn.Cell):
        def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
            return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]

    context.set_context(mode=mode)
    grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
    print('res:', res)
    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))


class TestCell(nn.Cell):
    def __init__(self, param):
        super().__init__()
        self.a = Tensor(np.array([[1, 2], [3, 4]]))
        self.param = param

    def construct(self, x):
        return self.a * self.param * x


class GradCellWithParameter(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
        self.param = self.net.param

    def construct(self, x):
        return self.grad(self.net, self.param)(x)


class GradCell(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad_all = ops.GradOperation(get_all=True)

    def construct(self, x):
        return self.grad_all(self.net)(x)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_input(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameter as input type.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    z = Tensor(np.array([[7, 8], [9, 0]]))
    a = GradCell(TestCell(x))(y)
    b = GradCell(TestCell(x))(z)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0].asnumpy(), b[0].asnumpy())


# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_parameter_as_input_and_fv(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    z = Tensor(np.array([[7, 8], [9, 0]]))
    a = GradCellWithParameter(TestCell(x))(y)
    b = GradCellWithParameter(TestCell(x))(z)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())


# PyNative run error.
# Support context.PYNATIVE_MODE later.
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
def test_grad_same_parameter_both_input_and_fv(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with the same Parameter used as input type and fv at the same time.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
    y = Tensor(np.array([[1, 2], [3, 4]]))
    a = GradCellWithParameter(TestCell(x))(x)
    b = GradCellWithParameter(TestCell(x))(y)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())


class TestCell2(nn.Cell):
    def __init__(self, param1, param2):
        super().__init__()
        self.a = Tensor(np.array([[1, 2], [3, 4]]))
        self.param1 = param1
        self.param2 = param2

    def construct(self, x):
        return self.a * self.param1 * self.param2 * x


class GradCellWithParameterTuple(nn.Cell):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
        self.param1 = self.net.param1
        self.param2 = self.net.param2
        self.params = ParameterTuple([self.param1, self.param2])

    def construct(self, x):
        return self.grad(self.net, self.params)(x)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_grad_parameter_as_input_and_fv2(mode):
    """
    Feature: Construct()/ms_function input type with back propagate.
    Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    x1 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x1')
    x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
    z = Tensor(np.array([[7, 8], [9, 0]]))
    a = GradCellWithParameterTuple(TestCell2(x1, x2))(y)
    b = GradCellWithParameterTuple(TestCell2(x1, x2))(z)
    print(f'a: {a}')
    print(f'b: {b}')
    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
    assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
    assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
@@ -1,82 +0,0 @@ (deleted test file)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test outermost net pass non_tensor inputs"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.add = P.TensorAdd()
        self.sub = P.Sub()

    def construct(self, tensor_x, tuple_a, list_b, tensor_y, tensor_z, dict_c):
        out = self.add(tensor_x, tuple_a[0])
        out = self.sub(out, list_b[1][1]["y"])
        out = self.add(out, tensor_y)
        out = self.sub(out, tensor_z)
        out = self.add(out, dict_c["u"])
        return out


class GradNet(nn.Cell):
    def __init__(self, net, get_all):
        super(GradNet, self).__init__()
        self.forward_net = net
        self.sens = Tensor(np.ones((2, 2), np.float32) * 5)
        self.grad_all = C.GradOperation(get_all=get_all)

    def construct(self, tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c):
        return self.grad_all(self.forward_net)(tuple_a, tensor_x, list_b, tensor_y, tensor_z, dict_c)


x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32) * 2)
z = Tensor(np.ones((2, 2), np.float32) * 3)
w = Tensor(np.ones((2, 2), np.float32) * 4)
p = Tensor(np.ones((2, 2), np.float32) * 5)
u = Tensor(np.ones((2, 2), np.float32) * 6)
arg_t0 = (x, y, z, w)
arg_l0 = [[x, x], [[x, y], {"x": x, "y": y, "z": z, "p": p}]]
args_d0 = {"x": x, "y": y, "u": u}


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_non_tensor_inputs():
    # grad first input
    grad_fist_input_tensor_net = GradNet(Net(), get_all=False)
    ret = grad_fist_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0)
    assert np.allclose(ret.asnumpy(), np.ones((2, 2), np.float32))
    # grad all inputs
    grad_all_input_tensor_net = GradNet(Net(), get_all=True)
    ret_all = grad_all_input_tensor_net(z, arg_t0, arg_l0, w, p, args_d0)
    assert len(ret_all) == 3
    assert np.allclose(ret_all[0].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[1].asnumpy(), np.ones((2, 2), np.float32))
    assert np.allclose(ret_all[2].asnumpy(), np.ones((2, 2), np.float32) * -1)
@@ -36,7 +36,7 @@ constexpr auto kDependRealInputSize = 2;
 ParameterPtr TestCreateParameter(const KernelGraphPtr &g, const std::string &name,
 const abstract::AbstractBasePtr &abstract) {
 MS_EXCEPTION_IF_NULL(g);
- auto parameter = g->AddWeightParameter(name);
+ auto parameter = g->AddFvParameter(name, abstract->BuildValue());
 if (parameter == nullptr) {
 MS_LOG(ERROR) << "Cannot add weight parameter!";
 }
@@ -19,9 +19,9 @@ import pytest
 import mindspore.nn as nn
 from mindspore.common import mutable
 from mindspore import Tensor, Parameter, ParameterTuple
- from mindspore import context
 from mindspore.ops import composite as C
 import mindspore.ops as ops
+ from mindspore import context


 @pytest.fixture(scope="module", autouse=True)
@@ -91,29 +91,7 @@ def test_grad_first_input_net(mode):

     context.set_context(mode=mode)
     grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
-    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
+    grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
-    print('res:', res)
-    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))
-
-
-# PyNative run error.
-# Support context.PYNATIVE_MODE later.
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
-def test_grad_first_input_net_pynative_error(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Normal input type.
-    Expectation: No exception.
-    """
-    class FirstInputTensorNet(nn.Cell):
-        def construct(self, tensor_a, tuple_a, list_b, tensor_b, tensor_c, dict_c):
-            return tensor_a + tuple_a[2] - list_b[1][1]["y"] + tensor_b - tensor_c + dict_c["y"]
-
-    context.set_context(mode=mode)
-    grad_fist_input_tensor_net = GradNet1(FirstInputTensorNet(), get_all=False)
-    res = grad_fist_input_tensor_net(tensor_z, tuple_arg, list_arg, tensor_w, tensor_y, dict_arg)
-    print('res:', res)
-    assert np.allclose(res.asnumpy(), np.ones((2, 2), np.float32))


 @pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
@@ -149,7 +127,6 @@ def test_outermost_net_pass_parameter(mode):


 # Support the Parameter as outermost input.
- # Support context.PYNATIVE_MODE UT later.
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE])
 def test_outermost_net_pass_tuple_including_parameter(mode):
     """
@@ -163,7 +140,6 @@ def test_outermost_net_pass_tuple_including_parameter(mode):


 # Support the Parameter as outermost input.
- # Support context.PYNATIVE_MODE UT later.
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE])
 def test_outermost_net_pass_list_including_parameter(mode):
     """
@@ -177,7 +153,6 @@ def test_outermost_net_pass_list_including_parameter(mode):


 # Support the Parameter as outermost input.
- # Support context.PYNATIVE_MODE UT later.
 @pytest.mark.parametrize('mode', [context.GRAPH_MODE])
 def test_grad_net_pass_dict_including_parameter(mode):
     """
@@ -190,96 +165,6 @@ def test_grad_net_pass_dict_including_parameter(mode):
     forward_net(tuple_arg, tensor_z, list_arg, SCALAR_NUM, SCALAR_NUM, mutable_dict, flag_0)
-
-
-class TestCell(nn.Cell):
-    def __init__(self, param):
-        super().__init__()
-        self.a = Tensor(np.array([[1, 2], [3, 4]]))
-        self.param = param
-
-    def construct(self, x):
-        return self.a * self.param * x
-
-
-class GradCellWithParameter(nn.Cell):
-    def __init__(self, net):
-        super().__init__()
-        self.net = net
-        self.grad = ops.GradOperation(get_all=True, get_by_list=True)
-        self.param = self.net.param
-
-    def construct(self, x):
-        return self.grad(self.net, self.param)(x)
-
-
-class GradCell(nn.Cell):
-    def __init__(self, net):
-        super().__init__()
-        self.net = net
-        self.grad_all = ops.GradOperation(get_all=True)
-
-    def construct(self, x):
-        return self.grad_all(self.net)(x)
-
-
-@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
-def test_grad_parameter_input(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Grad with Parameter as input type.
-    Expectation: No exception.
-    """
-    context.set_context(mode=mode)
-    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
-    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
-    z = Tensor(np.array([[7, 8], [9, 0]]))
-    a = GradCell(TestCell(x))(y)
-    b = GradCell(TestCell(x))(z)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0].asnumpy(), b[0].asnumpy())
-
-
-# PyNative run error.
-# Support context.PYNATIVE_MODE later.
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
-def test_grad_parameter_as_input_and_fv(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Grad with Parameters as input type and fv.
-    Expectation: No exception.
-    """
-    context.set_context(mode=mode)
-    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
-    y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
-    z = Tensor(np.array([[7, 8], [9, 0]]))
-    a = GradCellWithParameter(TestCell(x))(y)
-    b = GradCellWithParameter(TestCell(x))(z)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
-    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
-
-
-# PyNative run error.
-# Support context.PYNATIVE_MODE later.
-@pytest.mark.parametrize('mode', [context.GRAPH_MODE])
-def test_grad_same_parameter_both_input_and_fv(mode):
-    """
-    Feature: Construct()/ms_function input type with back propagate.
-    Description: Grad with the same Parameter used as input type and fv at the same time.
-    Expectation: No exception.
-    """
-    context.set_context(mode=mode)
-    x = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x')
-    y = Tensor(np.array([[1, 2], [3, 4]]))
-    a = GradCellWithParameter(TestCell(x))(x)
-    b = GradCellWithParameter(TestCell(x))(y)
-    print(f'a: {a}')
-    print(f'b: {b}')
-    assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
-    assert np.array_equal(a[1].asnumpy(), b[1].asnumpy())
-
-
 class TestCell2(nn.Cell):
     def __init__(self, param1, param2):
         super().__init__()
@@ -329,7 +214,7 @@ class GradCellWithTupleOfParameter(nn.Cell):


 @pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
- def test_grad_parameter_as_input_and_fv2(mode):
+ def test_grad_parameter_tuple(mode):
     """
     Feature: Construct()/ms_function input type with back propagate.
     Description: Grad with Parameters as input type and fv. ParameterTuple as fv.
|
||||||
x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
|
x2 = Parameter(Tensor(np.array([[1, 2], [3, 4]])), name='input_x2')
|
||||||
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
|
y = Parameter(Tensor(np.array([[7, 8], [9, 0]])), name='input_y')
|
||||||
z = Tensor(np.array([[7, 8], [9, 0]]))
|
z = Tensor(np.array([[7, 8], [9, 0]]))
|
||||||
a = GradCellWithParameterTuple(TestCell2(x1, x2))(y)
|
GradCellWithParameterTuple(TestCell2(x1, x2))(y)
|
||||||
b = GradCellWithParameterTuple(TestCell2(x1, x2))(z)
|
GradCellWithParameterTuple(TestCell2(x1, x2))(z)
|
||||||
print(f'a: {a}')
|
|
||||||
print(f'b: {b}')
|
|
||||||
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
|
|
||||||
assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
|
|
||||||
assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')
|
@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')
|
||||||