forked from mindspore-Ecosystem/mindspore
!46993 support kwargs for top graph
Merge pull request !46993 from huanghui/support-kwargs
This commit is contained in:
commit
ce38106313
|
@ -86,8 +86,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_l
|
|||
return res_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key_value), dict_get_item});
|
||||
});
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The arguments of UnpackCall operator should be tuple, list or dict, but got "
|
||||
<< args_abs_list[index]->ToString();
|
||||
// For the mixed arguments: func(a, *args)
|
||||
AnfNodePtr param = res_graph->add_parameter();
|
||||
(void)elems.emplace_back(param);
|
||||
}
|
||||
}
|
||||
// Add to order list to trace if fn_node had side effect.
|
||||
|
|
|
@ -157,8 +157,8 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
py::arg("incremental") = py::bool_(false), py::arg("obf_ratio") = py::float_(1.0),
|
||||
py::arg("branch_control_input") = py::int_(0), "Get graph proto of dynamic-obfuscated model.")
|
||||
.def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph")
|
||||
.def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
|
||||
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
|
||||
.def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("kwargs"),
|
||||
py::arg("phase") = py::str(""), py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
|
||||
.def("updata_param_node_default_input", &GraphExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
|
||||
py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
|
||||
.def("get_parameter_layout", &GraphExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
|
||||
|
|
|
@ -652,8 +652,8 @@ FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const Function
|
|||
// Save the function node to block
|
||||
func_block->WriteVariable(function_name, NewValueNode(current_fg));
|
||||
|
||||
py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
|
||||
(void)ParseStatements(func_block, funcObj);
|
||||
py::object func_obj = python_adapter::GetPyObjAttr(node, "body");
|
||||
(void)ParseStatements(func_block, func_obj);
|
||||
|
||||
// Add unused variables as isolate nodes.
|
||||
for (auto &func_block_item : func_block_list_) {
|
||||
|
@ -1337,15 +1337,14 @@ void Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
|
|||
values.push_back(ret_node);
|
||||
}
|
||||
}
|
||||
auto keys_tuple = GenerateMakeTuple(block, keys);
|
||||
auto values_tuple = GenerateMakeTuple(block, values);
|
||||
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
|
||||
std::vector<AnfNodePtr> make_dict_nodes;
|
||||
make_dict_nodes.push_back(make_dict_op);
|
||||
make_dict_nodes.push_back(keys_tuple);
|
||||
make_dict_nodes.push_back(values_tuple);
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
args_context->packed_arguments.push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
|
||||
if (!keys.empty()) {
|
||||
auto keys_tuple = GenerateMakeTuple(block, keys);
|
||||
auto values_tuple = GenerateMakeTuple(block, values);
|
||||
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
|
||||
std::vector<AnfNodePtr> make_dict_nodes = {make_dict_op, keys_tuple, values_tuple};
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
args_context->packed_arguments.push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3562,7 +3561,7 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|||
param->set_name(name);
|
||||
MS_EXCEPTION_IF_NULL(param->debug_info());
|
||||
param->debug_info()->set_name(name);
|
||||
param->debug_info()->set_location(param->debug_info()->location());
|
||||
param->debug_info()->set_location(orig_param->debug_info()->location());
|
||||
param->set_is_top_graph_param(true);
|
||||
}
|
||||
func_graph->set_has_vararg(current_graph->has_vararg());
|
||||
|
|
|
@ -76,6 +76,7 @@
|
|||
#include "include/backend/data_queue/data_queue_mgr.h"
|
||||
#ifndef ENABLE_SECURITY
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#endif
|
||||
#if defined(__linux__) && defined(WITH_BACKEND)
|
||||
#include "ps/constants.h"
|
||||
|
@ -403,19 +404,18 @@ void CheckArgsValid(const py::object &source_obj, const py::tuple &args) {
|
|||
}
|
||||
}
|
||||
|
||||
py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py::tuple &args,
|
||||
py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
|
||||
bool enable_tuple_broaden) {
|
||||
MS_LOG(DEBUG) << "GenerateArgumentsKey args size: " << args.size()
|
||||
<< ", enable_tuple_broaden: " << enable_tuple_broaden;
|
||||
|
||||
abstract::AbstractBasePtrList args_abs;
|
||||
cur_convert_input_.clear();
|
||||
std::size_t size = args.size();
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
ClearCurConvertInput();
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
ValuePtr converted = nullptr;
|
||||
if (!parse::ConvertData(args[i], &converted)) {
|
||||
MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
|
||||
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
|
||||
MS_LOG(EXCEPTION) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
|
||||
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
|
||||
}
|
||||
AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
|
||||
(void)args_abs.emplace_back(abs);
|
||||
|
@ -423,6 +423,20 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py
|
|||
// so we keep all inputs for subsequent procedure.
|
||||
(void)cur_convert_input_.emplace(args[i].ptr(), std::make_pair(converted, abs));
|
||||
}
|
||||
for (const auto &item : kwargs) {
|
||||
ValuePtr key = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &value);
|
||||
if (!success) {
|
||||
MS_LOG(EXCEPTION) << "parse::ConvertData for argument (" << py::str(item.first) << ": " << py::str(item.second)
|
||||
<< ") failed.";
|
||||
}
|
||||
AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden);
|
||||
auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
|
||||
(void)args_abs.emplace_back(keyword_arg_abs);
|
||||
(void)cur_convert_input_.emplace(item.first.ptr(), std::make_pair(value, keyword_arg_abs));
|
||||
}
|
||||
|
||||
// If cache matched no need CheckArgsValid
|
||||
auto iter = kArgsCache.find(args_abs);
|
||||
|
@ -867,8 +881,8 @@ void GraphExecutorPy::CleanCompileRes(const ResourcePtr &resource) {
|
|||
MS_LOG(INFO) << "Clean compile resource end";
|
||||
}
|
||||
|
||||
bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj,
|
||||
bool use_vm) {
|
||||
bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
|
||||
const py::object &phase_obj, bool use_vm) {
|
||||
// Check if the phase is valid.
|
||||
if ((!py::isinstance<py::str>(phase_obj))) {
|
||||
MS_LOG(ERROR) << "The `phase` must be string.";
|
||||
|
@ -887,7 +901,8 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
phase_ = phase;
|
||||
auto obj_desc = GetObjDesc(source_obj);
|
||||
MS_LOG(INFO) << "Start compiling, phase: " << phase;
|
||||
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args));
|
||||
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args))
|
||||
<< "\nkwargs: " << py::str(const_cast<py::dict &>(kwargs));
|
||||
EventMessage::PrintCompileStartMsg(phase, obj_desc);
|
||||
|
||||
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
|
||||
|
@ -914,42 +929,15 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
// Get the parameters items and add the value to args_abs.
|
||||
abstract::AbstractBasePtrList args_abs;
|
||||
std::vector<ValuePtr> arguments;
|
||||
std::size_t size = args.size();
|
||||
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
|
||||
bool is_parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
|
||||
parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel;
|
||||
bool is_auto_parallel = is_parallel_mode && !py::hasattr(source_obj, parallel::kSkipAutoParallelCompile) &&
|
||||
!py::hasattr(source_obj, parallel::kKeepInputUnchanged);
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
ValuePtr converted = nullptr;
|
||||
// In some parallel mode need full_tensor which cause the args of GenerateArgumentsKey not same to compile,
|
||||
// So can't use cur_convert_input_ directly.
|
||||
auto iter = cur_convert_input_.find(args[i].ptr());
|
||||
if (iter != cur_convert_input_.end()) {
|
||||
(void)arguments.emplace_back(iter->second.first);
|
||||
if (is_auto_parallel) {
|
||||
auto abs_item = iter->second.second->Clone();
|
||||
(void)parallel::ExtendInputArgsAbstractShape(abs_item, i);
|
||||
(void)args_abs.emplace_back(abs_item);
|
||||
continue;
|
||||
}
|
||||
(void)args_abs.emplace_back(iter->second.second);
|
||||
continue;
|
||||
}
|
||||
bool succ = parse::ConvertData(args[i], &converted);
|
||||
if (!succ) {
|
||||
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
|
||||
}
|
||||
(void)arguments.emplace_back(converted);
|
||||
auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
|
||||
if (is_auto_parallel) {
|
||||
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
|
||||
}
|
||||
(void)args_abs.emplace_back(args_abstract_item);
|
||||
}
|
||||
ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
|
||||
resource->set_arguments(arguments);
|
||||
resource->set_args_abs(args_abs);
|
||||
executor_info->arg_list_size = size;
|
||||
executor_info->arg_list_size = args.size() + kwargs.size();
|
||||
executor_info->resource = resource;
|
||||
info_[phase] = executor_info;
|
||||
pip->Run();
|
||||
|
@ -969,6 +957,59 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
return true;
|
||||
}
|
||||
|
||||
void GraphExecutorPy::ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
|
||||
abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments) {
|
||||
MS_EXCEPTION_IF_NULL(args_abs);
|
||||
MS_EXCEPTION_IF_NULL(arguments);
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
// In some parallel mode need full_tensor which cause the args of GenerateArgumentsKey not same to compile,
|
||||
// So can't use cur_convert_input_ directly.
|
||||
auto iter = cur_convert_input_.find(args[i].ptr());
|
||||
if (iter != cur_convert_input_.end()) {
|
||||
(void)arguments->emplace_back(iter->second.first);
|
||||
if (is_auto_parallel) {
|
||||
auto abs_item = iter->second.second->Clone();
|
||||
(void)parallel::ExtendInputArgsAbstractShape(abs_item, i);
|
||||
(void)args_abs->emplace_back(abs_item);
|
||||
continue;
|
||||
}
|
||||
(void)args_abs->emplace_back(iter->second.second);
|
||||
continue;
|
||||
}
|
||||
ValuePtr converted = nullptr;
|
||||
bool success = parse::ConvertData(args[i], &converted);
|
||||
if (!success) {
|
||||
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
|
||||
}
|
||||
(void)arguments->emplace_back(converted);
|
||||
auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
|
||||
if (is_auto_parallel) {
|
||||
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
|
||||
}
|
||||
(void)args_abs->emplace_back(args_abstract_item);
|
||||
}
|
||||
for (const auto &item : kwargs) {
|
||||
auto iter = cur_convert_input_.find(item.first.ptr());
|
||||
if (iter != cur_convert_input_.end()) {
|
||||
(void)arguments->emplace_back(iter->second.first);
|
||||
(void)args_abs->emplace_back(iter->second.second);
|
||||
continue;
|
||||
}
|
||||
ValuePtr key = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
|
||||
parse::ConvertData(py::cast<py::object>(item.second), &value);
|
||||
if (!success) {
|
||||
MS_LOG(EXCEPTION) << "Fail to convert the argument (" << py::str(item.first) << ": " << py::str(item.second)
|
||||
<< ").";
|
||||
}
|
||||
AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden_);
|
||||
auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
|
||||
(void)arguments->emplace_back(value);
|
||||
(void)args_abs->emplace_back(keyword_arg_abs);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionItem> &actions,
|
||||
const std::string &phase) {
|
||||
// filter action after validate when 'export'.
|
||||
|
@ -1002,12 +1043,12 @@ void GraphExecutorPy::ReleaseResource(const py::object &phase) {
|
|||
}
|
||||
}
|
||||
|
||||
bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase,
|
||||
bool use_vm) {
|
||||
bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
|
||||
const py::object &phase, bool use_vm) {
|
||||
bool ret_value = false;
|
||||
try {
|
||||
ProcessStatus::GetInstance().RecordStart("CompileInner");
|
||||
ret_value = CompileInner(source_obj, args, phase, use_vm);
|
||||
ret_value = CompileInner(source_obj, args, kwargs, phase, use_vm);
|
||||
ProcessStatus::GetInstance().RecordEnd();
|
||||
ProcessStatus::GetInstance().Print();
|
||||
} catch (const py::error_already_set &ex) {
|
||||
|
@ -1277,9 +1318,8 @@ bool Pipeline::NeedCreateBackend() {
|
|||
|
||||
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
|
||||
MS_EXCEPTION_IF_NULL(arg_list);
|
||||
std::size_t size = args.size();
|
||||
bool arg_list_inited = !arg_list->empty();
|
||||
for (std::size_t i = 0; i < size; i++) {
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
py::object arg = args[i];
|
||||
ValuePtr converted = nullptr;
|
||||
bool succ = parse::ConvertData(arg, &converted);
|
||||
|
|
|
@ -77,8 +77,12 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
|
|||
const std::string &phase() const { return phase_; }
|
||||
const std::map<std::string, std::string> &jit_config() const { return jit_config_; }
|
||||
void SaveCompiledGraph(const std::string &phase);
|
||||
bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj, bool use_vm);
|
||||
bool Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase, bool use_vm);
|
||||
void ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
|
||||
abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments);
|
||||
bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
|
||||
const py::object &phase_obj, bool use_vm);
|
||||
bool Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
|
||||
bool use_vm);
|
||||
|
||||
void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list);
|
||||
|
||||
|
@ -131,7 +135,8 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
|
|||
#endif
|
||||
|
||||
// Generate a key for mapping function graph
|
||||
py::object GenerateArgumentsKey(const py::object &obj, const py::tuple &args, bool enable_tuple_broaden = false);
|
||||
py::object GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
|
||||
bool enable_tuple_broaden = false);
|
||||
|
||||
void ClearCurConvertInput();
|
||||
|
||||
|
|
|
@ -209,10 +209,11 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
return *key_value == *item.first->BuildValue();
|
||||
});
|
||||
if (it == dict_elems.end()) {
|
||||
// For dict[key], if key is not exist, will raise a KeyError exception.
|
||||
// For dict[key], if key is not exist, will raise a ValueError exception.
|
||||
// For dict.get('key', default=None), if key is not exist, will return the default value during dict_get.
|
||||
MS_EXCEPTION(KeyError) << "The key " << key_value->ToString()
|
||||
<< " does not exist in the dict:" << args_spec_list[0]->BuildValue()->ToString();
|
||||
// Python KeyError will print escape character. So use ValueError instead of KeyError here.
|
||||
MS_EXCEPTION(ValueError) << "The key " << key_value->ToString()
|
||||
<< " does not exist in the dict:" << args_spec_list[0]->BuildValue()->ToString();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ std::string Keyword::ToString() const {
|
|||
MS_EXCEPTION_IF_NULL(value_);
|
||||
buffer << "Keyword[";
|
||||
buffer << "key : " << key_;
|
||||
buffer << "value : " << value_->ToString();
|
||||
buffer << ", value : " << value_->ToString();
|
||||
buffer << "]";
|
||||
}
|
||||
return buffer.str();
|
||||
|
|
|
@ -303,7 +303,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
|
||||
|
||||
void GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
||||
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
|
||||
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, int pos_args_input_count,
|
||||
std::vector<AnfNodePtr> *specialized_parameter_list,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
|
||||
|
||||
|
|
|
@ -122,25 +122,30 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, int var
|
|||
|
||||
void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
||||
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
|
||||
std::vector<AnfNodePtr> *specialized_parameter_list,
|
||||
int pos_args_input_count, std::vector<AnfNodePtr> *specialized_parameter_list,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
|
||||
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
|
||||
MS_EXCEPTION_IF_NULL(repl_nodes);
|
||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
|
||||
std::set<AnfNodePtr> kwarg_nodes;
|
||||
for (const auto &kwarg : kwarg_list) {
|
||||
for (size_t i = 0; i < kwarg_list.size(); ++i) {
|
||||
auto kwarg = kwarg_list[i];
|
||||
MS_EXCEPTION_IF_NULL(kwarg);
|
||||
std::string kw_param_name = kwarg->get_key();
|
||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||
AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
|
||||
// If not find corresponding parameter node.
|
||||
if (param_node == nullptr) {
|
||||
if (!has_kwarg()) {
|
||||
MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
|
||||
if (pos_args_input_count + i > specialized_graph->parameters().size() - 1) {
|
||||
MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
|
||||
}
|
||||
specialized_parameter_list->push_back(specialized_graph->parameters()[pos_args_input_count + i]);
|
||||
} else {
|
||||
ParameterPtr para = std::make_shared<Parameter>(specialized_graph);
|
||||
std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
|
||||
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
|
||||
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
|
||||
[param_name](const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -169,7 +174,6 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
|
|||
auto extract_node = specialized_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
|
||||
kwarg_nodes.insert(param_node);
|
||||
MS_EXCEPTION_IF_NULL(repl_nodes);
|
||||
(void)repl_nodes->emplace(param_node, extract_node);
|
||||
}
|
||||
}
|
||||
|
@ -182,7 +186,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
|
|||
const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
|
||||
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes,
|
||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
|
||||
if (has_kwarg()) {
|
||||
if (has_kwarg() && !kwarg_keys_tuple_nodes.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(specialized_graph);
|
||||
TraceGuard guard(
|
||||
std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
|
||||
|
@ -264,7 +268,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
|
|||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl_nodes;
|
||||
GenerateVarParams(specialized_graph, variable_args_count, pos_args_input_count, &specialized_parameter_list,
|
||||
&repl_nodes);
|
||||
GenerateKwParams(specialized_graph, kwarg_list, &specialized_parameter_list, &repl_nodes);
|
||||
GenerateKwParams(specialized_graph, kwarg_list, pos_args_input_count, &specialized_parameter_list, &repl_nodes);
|
||||
|
||||
GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);
|
||||
|
||||
|
|
|
@ -120,28 +120,30 @@ def _handle_func_args(func, *args, **kwargs):
|
|||
bound_arguments.apply_defaults()
|
||||
args = bound_arguments.args
|
||||
kwargs = bound_arguments.kwargs
|
||||
# After apply_defaults, kwargs should be empty here.
|
||||
if kwargs:
|
||||
raise ValueError(f"Failed to handle kwargs of {func.__name__}. Maybe you pass wrong arguments, "
|
||||
f"or there is a key in kwargs that is not used as a function argument, "
|
||||
f"args: {args}, kwargs: {kwargs}")
|
||||
|
||||
positional_args = 0
|
||||
default_args = 0
|
||||
has_var = False
|
||||
for value in inspect.signature(func).parameters.values():
|
||||
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
return args
|
||||
has_var = True
|
||||
if value.kind is inspect.Parameter.KEYWORD_ONLY:
|
||||
raise TypeError(f"Function {func.__name__}, MindSpore does not support keyword-only arg: {value}.")
|
||||
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
if value.default is inspect.Parameter.empty:
|
||||
positional_args += 1
|
||||
else:
|
||||
default_args += 1
|
||||
|
||||
if has_var:
|
||||
return args, kwargs
|
||||
|
||||
if len(args) < positional_args:
|
||||
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
|
||||
if len(args) > positional_args + default_args:
|
||||
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
|
||||
f"default argument, total {positional_args + default_args}, but got {len(args)}.")
|
||||
return args
|
||||
return args, kwargs
|
||||
|
||||
|
||||
sys_path = list(sys.path)
|
||||
|
@ -240,25 +242,39 @@ def _get_parameter_layout():
|
|||
return layout
|
||||
|
||||
|
||||
def _get_args_for_run(obj, args_list):
|
||||
"""Get the actual input args for runtime."""
|
||||
inputs = []
|
||||
for i in args_list:
|
||||
if isinstance(i, PythonTensor):
|
||||
if i.has_init:
|
||||
i.init_data()
|
||||
if not i.const_arg:
|
||||
inputs.append(i)
|
||||
elif isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
||||
inputs.append(i)
|
||||
elif hasattr(i, "__ms_mutable__") and getattr(i, "__ms_mutable__"):
|
||||
inputs.append(i)
|
||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||
inputs.append(i)
|
||||
elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(i, tuple) and \
|
||||
_check_all_tensor(i):
|
||||
inputs.append(i)
|
||||
return inputs
|
||||
def _handle_arg(obj, arg):
|
||||
"""Handle arg for runtime .If need handle the arg, return True"""
|
||||
if isinstance(arg, PythonTensor):
|
||||
if arg.has_init:
|
||||
arg.init_data()
|
||||
if not arg.const_arg:
|
||||
return arg
|
||||
elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
|
||||
return arg
|
||||
elif hasattr(arg, "__ms_mutable__") and getattr(arg, "__ms_mutable__"):
|
||||
return arg
|
||||
elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
|
||||
return arg
|
||||
elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
|
||||
_check_all_tensor(arg):
|
||||
return arg
|
||||
return None
|
||||
|
||||
|
||||
def _get_args_for_run(obj, args, kwargs):
|
||||
"""Get the actual input args and kwargs for runtime."""
|
||||
new_args = []
|
||||
for arg in args:
|
||||
new_arg = _handle_arg(obj, arg)
|
||||
if new_arg is not None:
|
||||
new_args.append(new_arg)
|
||||
|
||||
for _, value in kwargs.items():
|
||||
new_value = _handle_arg(obj, value)
|
||||
if new_value is not None:
|
||||
new_args.append(new_value)
|
||||
|
||||
return new_args
|
||||
|
||||
|
||||
class _MindsporeFunctionExecutor:
|
||||
|
@ -296,7 +312,7 @@ class _MindsporeFunctionExecutor:
|
|||
self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
|
||||
|
||||
@_wrap_func
|
||||
def __call__(self, *args):
|
||||
def __call__(self, *args, **kwargs):
|
||||
args_list = args
|
||||
if self.obj is not None:
|
||||
args_list = args_list[1:]
|
||||
|
@ -304,10 +320,10 @@ class _MindsporeFunctionExecutor:
|
|||
phase = ""
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
_pynative_executor.set_ms_function_compile_status(True, phase)
|
||||
phase = self.compile(args_list, self.fn.__name__)
|
||||
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
|
||||
_pynative_executor.set_ms_function_compile_status(False, phase)
|
||||
else:
|
||||
phase = self.compile(args_list, self.fn.__name__)
|
||||
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
|
||||
except Exception as err:
|
||||
_pynative_executor.clear_res()
|
||||
raise err
|
||||
|
@ -315,7 +331,7 @@ class _MindsporeFunctionExecutor:
|
|||
if context.get_context("precompile_only"):
|
||||
return None
|
||||
|
||||
new_inputs = self._generate_run_args(args_list)
|
||||
new_inputs = self._generate_run_args(args_list, kwargs)
|
||||
output = self._graph_executor(tuple(new_inputs), phase)
|
||||
if context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
output = _pynative_executor.grad_ms_function(output, *new_inputs)
|
||||
|
@ -331,7 +347,7 @@ class _MindsporeFunctionExecutor:
|
|||
|
||||
return output
|
||||
|
||||
def compile(self, args_list, method_name):
|
||||
def compile(self, method_name, *args, **kwargs):
|
||||
"""Returns pipeline for the given args."""
|
||||
# Check whether hook function registered on Cell object.
|
||||
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
|
||||
|
@ -340,9 +356,9 @@ class _MindsporeFunctionExecutor:
|
|||
f"If you want to use hook function, please use context.set_context to set "
|
||||
f"pynative mode and remove 'jit' decorator.")
|
||||
# Chose dynamic shape tensors or actual input tensors as compile args.
|
||||
compile_args = self._generate_compile_args(args_list)
|
||||
compile_args = self._generate_compile_args(args)
|
||||
# Restore the mutable attr for every arg.
|
||||
compile_args = _restore_mutable_attr(args_list, compile_args)
|
||||
compile_args = _restore_mutable_attr(args, compile_args)
|
||||
|
||||
generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
|
||||
str(self.fn.__code__.co_firstlineno)
|
||||
|
@ -371,7 +387,7 @@ class _MindsporeFunctionExecutor:
|
|||
self.enable_tuple_broaden = self.obj.enable_tuple_broaden
|
||||
|
||||
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
|
||||
key = self._graph_executor.generate_arguments_key(self.fn, compile_args, self.enable_tuple_broaden)
|
||||
key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
|
||||
phase = generate_name + '.' + str(key)
|
||||
if phase in ms_compile_cache:
|
||||
return phase
|
||||
|
@ -382,11 +398,11 @@ class _MindsporeFunctionExecutor:
|
|||
self._graph_executor.set_jit_config(self.jit_config_dict)
|
||||
|
||||
if self.obj is None:
|
||||
is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
|
||||
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
|
||||
else:
|
||||
if isinstance(self.obj, ms.nn.Cell):
|
||||
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
||||
is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
|
||||
is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
|
||||
|
||||
if not is_compile:
|
||||
raise RuntimeError("Executor compile failed.")
|
||||
|
@ -449,17 +465,18 @@ class _MindsporeFunctionExecutor:
|
|||
raise ValueError("The input args is incompatible with the args in `input_signature`!")
|
||||
return compile_args
|
||||
|
||||
def _generate_run_args(self, args_list):
|
||||
def _generate_run_args(self, args_list, kwargs):
|
||||
"""
|
||||
Generate input args, which are required for running.
|
||||
|
||||
Args:
|
||||
args_list (Tuple): Actual input args.
|
||||
kwargs (Dict): Actual input kwargs.
|
||||
|
||||
Returns:
|
||||
new_inputs, new input args, which are required for running.
|
||||
"""
|
||||
return _get_args_for_run(self, args_list)
|
||||
return _get_args_for_run(self, args_list, kwargs)
|
||||
|
||||
|
||||
# The attributes used to identify a given object.
|
||||
|
@ -577,14 +594,15 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
|
|||
if os.getenv("MS_JIT") == '0':
|
||||
return func(*args, **kwargs)
|
||||
|
||||
args = _handle_func_args(func, *args, **kwargs)
|
||||
args, kwargs = _handle_func_args(func, *args, **kwargs)
|
||||
|
||||
process_obj = None
|
||||
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
|
||||
process_obj = args[0]
|
||||
# only the function or cell instance wrapped by shard will fall into this branch
|
||||
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
|
||||
process_obj = hash_args
|
||||
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
|
||||
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args, **kwargs)
|
||||
return out
|
||||
|
||||
return staging_specialize
|
||||
|
@ -1338,16 +1356,17 @@ class _CellGraphExecutor:
|
|||
if "train" in phase and (enable_compile_cache is True or enable_compile_cache == "1"):
|
||||
self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
|
||||
|
||||
def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None):
|
||||
def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs):
|
||||
"""
|
||||
Compiles graph.
|
||||
|
||||
Args:
|
||||
obj (Function/Cell): The function or cell instance need compile.
|
||||
args (tuple): Function or cell input arguments.
|
||||
phase (str): The name of compile phase. Default: 'predict'.
|
||||
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
|
||||
jit_config_dict (dict): Jit config for compile. Default: None.
|
||||
args (tuple): Args of the Cell object.
|
||||
kwargs (dict): Kwargs of the Cell object.
|
||||
|
||||
Return:
|
||||
Str, the full phase of the cell.
|
||||
|
@ -1357,14 +1376,13 @@ class _CellGraphExecutor:
|
|||
if not hasattr(obj, obj.__parse_method__):
|
||||
raise AttributeError(
|
||||
'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
|
||||
args_list = args
|
||||
|
||||
self.enable_tuple_broaden = False
|
||||
if hasattr(obj, "enable_tuple_broaden"):
|
||||
self.enable_tuple_broaden = obj.enable_tuple_broaden
|
||||
|
||||
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
|
||||
key = self._graph_executor.generate_arguments_key(obj, args_list, self.enable_tuple_broaden)
|
||||
key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
|
||||
obj.arguments_key = str(key)
|
||||
phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
||||
|
||||
|
@ -1374,14 +1392,14 @@ class _CellGraphExecutor:
|
|||
|
||||
obj.check_names()
|
||||
_check_full_batch()
|
||||
self._set_dataset_mode(args_list)
|
||||
self._set_dataset_mode(args)
|
||||
self._set_compile_cache_dep_files(phase)
|
||||
|
||||
enable_ge = context.get_context("enable_ge")
|
||||
self._graph_executor.set_weights_values(obj.parameters_dict())
|
||||
if jit_config_dict:
|
||||
self._graph_executor.set_jit_config(jit_config_dict)
|
||||
result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
|
||||
result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
|
||||
obj.compile_cache.add(phase)
|
||||
if not result:
|
||||
raise RuntimeError("Executor compile failed.")
|
||||
|
@ -1458,6 +1476,8 @@ class _CellGraphExecutor:
|
|||
Run the specific graph.
|
||||
|
||||
Args:
|
||||
obj (Cell): The cell object.
|
||||
args (tuple): Args of the Cell object.
|
||||
phase (str): The phase name. Default: 'predict'.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -443,36 +443,38 @@ class Cell(Cell_):
|
|||
output = self._run_forward_hook(cast_inputs, output)
|
||||
return output
|
||||
|
||||
def _check_construct_args(self, *inputs, **kwargs):
|
||||
def _check_construct_args(self, *args):
|
||||
"""Check the args needed by the function construct"""
|
||||
if kwargs:
|
||||
raise ValueError(f"For 'Cell', expect no kwargs here, maybe you pass wrong arguments, "
|
||||
f"or there is a key in kwargs that is not used as a function argument. "
|
||||
f"args: {inputs}, kwargs: {kwargs}")
|
||||
positional_args = 0
|
||||
default_args = 0
|
||||
has_var = False
|
||||
for value in inspect.signature(self.construct).parameters.values():
|
||||
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
return
|
||||
has_var = True
|
||||
if value.kind is inspect.Parameter.KEYWORD_ONLY:
|
||||
raise TypeError(f"For the method 'construct', MindSpore does not support keyword-only arg: {value}.")
|
||||
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
||||
if value.default is inspect.Parameter.empty:
|
||||
positional_args += 1
|
||||
else:
|
||||
default_args += 1
|
||||
|
||||
if len(inputs) < positional_args:
|
||||
if has_var:
|
||||
return
|
||||
|
||||
if len(args) < positional_args:
|
||||
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, "
|
||||
f"but got {len(inputs)}. When using set_inputs, please make sure that all networks "
|
||||
f"but got {len(args)}. When using set_inputs, please make sure that all networks "
|
||||
f"and loss functions are configured with set_inputs.")
|
||||
|
||||
if len(inputs) > positional_args + default_args:
|
||||
if len(args) > positional_args + default_args:
|
||||
construct_inputs_names = self.construct.__code__.co_varnames
|
||||
if 'self' not in construct_inputs_names:
|
||||
raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
|
||||
|
||||
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
|
||||
f"{default_args} default argument, total {positional_args + default_args}, "
|
||||
f"but got {len(inputs)}.")
|
||||
f"but got {len(args)}.")
|
||||
|
||||
def _hook_fn_registered(self):
|
||||
'''Hook function in graph mode'''
|
||||
|
@ -615,7 +617,7 @@ class Cell(Cell_):
|
|||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.__class__.construct is Cell.construct:
|
||||
raise AttributeError("For 'Cell', the method 'construct' is not defined. ")
|
||||
raise AttributeError("For 'Cell', the method 'construct' is not defined.")
|
||||
|
||||
if kwargs:
|
||||
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
|
||||
|
@ -625,11 +627,11 @@ class Cell(Cell_):
|
|||
|
||||
# Run in Graph mode.
|
||||
if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
|
||||
self._check_construct_args(*args, **kwargs)
|
||||
self._check_construct_args(*args)
|
||||
if self._hook_fn_registered():
|
||||
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
|
||||
f"function, please use context.set_context to set pynative mode.")
|
||||
out = self.compile_and_run(*args)
|
||||
out = self.compile_and_run(*args, **kwargs)
|
||||
return out
|
||||
|
||||
# Run in PyNative mode.
|
||||
|
@ -918,24 +920,25 @@ class Cell(Cell_):
|
|||
|
||||
return self._dynamic_shape_inputs
|
||||
|
||||
def compile(self, *inputs):
|
||||
def compile(self, *args, **kwargs):
|
||||
"""
|
||||
Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
|
||||
|
||||
Args:
|
||||
inputs (tuple): Inputs of the Cell object.
|
||||
args (tuple): Args of the Cell object.
|
||||
kwargs (dict): Kwargs of the Cell object.
|
||||
"""
|
||||
if self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None:
|
||||
_cell_graph_executor.compile(self, *inputs, phase=self.phase,
|
||||
jit_config_dict=self._jit_config_dict)
|
||||
_cell_graph_executor.compile(self, phase=self.phase,
|
||||
jit_config_dict=self._jit_config_dict, *args, **kwargs)
|
||||
else:
|
||||
self._check_compile_dynamic_shape(*inputs)
|
||||
self._check_compile_dynamic_shape(*args)
|
||||
self.saved_dynamic_shape = self._dynamic_shape_inputs
|
||||
_cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
|
||||
jit_config_dict=self._jit_config_dict)
|
||||
jit_config_dict=self._jit_config_dict, **kwargs)
|
||||
logger.debug("Compiled Graph with dynamic shape")
|
||||
|
||||
def compile_and_run(self, *inputs):
|
||||
def compile_and_run(self, *args, **kwargs):
|
||||
"""
|
||||
Compile and run Cell, the input must be consistent with the input defined in construct.
|
||||
|
||||
|
@ -943,15 +946,16 @@ class Cell(Cell_):
|
|||
It is not recommended to call directly.
|
||||
|
||||
Args:
|
||||
inputs (tuple): Inputs of the Cell object.
|
||||
args (tuple): Args of the Cell object.
|
||||
kwargs (dict): Kwargs of the Cell object.
|
||||
|
||||
Returns:
|
||||
Object, the result of executing.
|
||||
"""
|
||||
self.compile(*inputs)
|
||||
self.compile(*args, **kwargs)
|
||||
|
||||
new_inputs = _get_args_for_run(self, inputs)
|
||||
return _cell_graph_executor(self, *new_inputs, phase=self.phase)
|
||||
new_args = _get_args_for_run(self, args, kwargs)
|
||||
return _cell_graph_executor(self, *new_args, phase=self.phase)
|
||||
|
||||
def auto_parallel_compile_and_run(self):
|
||||
"""
|
||||
|
@ -1046,7 +1050,7 @@ class Cell(Cell_):
|
|||
f"but got type {type(child_cell)}.")
|
||||
self._cells[child_name] = child_cell
|
||||
|
||||
def construct(self, *inputs, **kwargs):
|
||||
def construct(self, *args, **kwargs):
|
||||
"""
|
||||
Defines the computation to be performed. This method must be overridden by all subclasses.
|
||||
|
||||
|
@ -1054,7 +1058,7 @@ class Cell(Cell_):
|
|||
It is not supported currently that inputs contain both tuple and non-tuple types at same time.
|
||||
|
||||
Args:
|
||||
inputs (tuple): Tuple of variable parameters.
|
||||
args (tuple): Tuple of variable parameters.
|
||||
kwargs (dict): Dictionary of variable keyword parameters.
|
||||
|
||||
Returns:
|
||||
|
@ -2303,13 +2307,13 @@ class GraphCell(Cell):
|
|||
def construct(self, *inputs):
|
||||
return self.graph(*inputs)
|
||||
|
||||
def __call__(self, *inputs):
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.phase = "graph_load_from_mindir"
|
||||
self._add_attr("graph_load_from_mindir", self.graph)
|
||||
if not self.obf_random_seed:
|
||||
return self.compile_and_run(*inputs)
|
||||
return self.compile_and_run(*args, **kwargs)
|
||||
append_input = Tensor((numpy.ones((1, 1)) * self._branch_control_input).astype(numpy.int32))
|
||||
return self.compile_and_run(*inputs, append_input)
|
||||
return self.compile_and_run(*args, append_input, **kwargs)
|
||||
|
||||
|
||||
def _check_param_list_tuple(value):
|
||||
|
|
|
@ -565,17 +565,17 @@ class _Grad(GradOperation_):
|
|||
dynamic_shape_inputs = fn.get_inputs()
|
||||
if self.get_by_position:
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights, grad_position)(*args)
|
||||
def after_grad(*args, **kwargs):
|
||||
return grad_(fn, weights, grad_position)(*args, **kwargs)
|
||||
else:
|
||||
if self.get_by_list:
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights)(*args)
|
||||
def after_grad(*args, **kwargs):
|
||||
return grad_(fn, weights)(*args, **kwargs)
|
||||
else:
|
||||
@jit(input_signature=dynamic_shape_inputs)
|
||||
def after_grad(*args):
|
||||
return grad_(fn)(*args)
|
||||
def after_grad(*args, **kwargs):
|
||||
return grad_(fn)(*args, **kwargs)
|
||||
elif self.pynative_:
|
||||
@_wrap_func
|
||||
def after_grad(*args, **kwargs):
|
||||
|
@ -663,8 +663,8 @@ class _Vmap(VmapOperation_):
|
|||
vmap_ = self
|
||||
|
||||
@jit
|
||||
def after_vmap(*args):
|
||||
return vmap_(fn, in_axes, out_axes)(*args)
|
||||
def after_vmap(*args, **kwargs):
|
||||
return vmap_(fn, in_axes, out_axes)(*args, **kwargs)
|
||||
|
||||
self.vmap_fn = after_vmap
|
||||
self.fn = fn
|
||||
|
|
|
@ -1455,7 +1455,7 @@ def _msfunc_info(net, *inputs):
|
|||
# pylint: disable=protected-access
|
||||
net_dict = OrderedDict()
|
||||
_ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
|
||||
graph_id = _ms_func_executor.compile(args_list=inputs, method_name=net.__name__)
|
||||
graph_id = _ms_func_executor.compile(net.__name__, *inputs)
|
||||
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
||||
params = _ms_func_executor._graph_executor.get_params(graph_id)
|
||||
for name, value in params.items():
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter, ParameterTuple
|
||||
from mindspore import Tensor, Parameter, ParameterTuple, jit
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.ops as ops
|
||||
|
@ -320,3 +320,82 @@ def test_grad_parameter_as_input_and_fv2(mode):
|
|||
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
|
||||
assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
|
||||
assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
|
||||
|
||||
|
||||
tensor1 = Tensor([1])
|
||||
tensor2 = Tensor([2])
|
||||
tensor3 = Tensor([3])
|
||||
tensor4 = Tensor([4])
|
||||
tensor5 = Tensor([5])
|
||||
tensor6 = Tensor([6])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_mixed_arguments():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments for cell.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class FNet(nn.Cell):
|
||||
def construct(self, a, *args, **kwargs):
|
||||
x = a + args[0] + args[1] + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = FNet()
|
||||
assert net(tensor1, tensor2, tensor3, b=tensor4, c=tensor5, d=tensor6).asnumpy() == [12]
|
||||
assert net(1, 2, 3, d=tensor6).asnumpy() == [12]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cell_mixed_arguments_with_grad():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments for jit function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class FNet(nn.Cell):
|
||||
def construct(self, *args, **kwargs):
|
||||
x = args[0] + args[1] - kwargs["d"]
|
||||
return x
|
||||
|
||||
class GNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GNet, self).__init__()
|
||||
self.net = net
|
||||
self.grad_op = ops.GradOperation()
|
||||
|
||||
def construct(self, *args, **kwargs):
|
||||
gradient_function = self.grad_op(self.net)
|
||||
return gradient_function(*args, **kwargs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
grad_net = GNet(FNet())
|
||||
assert grad_net(tensor1, tensor2, tensor3, d=tensor4, e=tensor5).asnumpy() == [1]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_jit_mixed_arguments():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments for jit function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def func(a, *args, **kwargs):
|
||||
x = a + args[0] + args[1] + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
assert func(tensor1, tensor2, tensor3, b=tensor4, c=tensor5, d=tensor6).asnumpy() == [12]
|
||||
assert func(1, 2, 3, d=tensor6).asnumpy() == [12]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -22,8 +22,10 @@ from mindspore.nn.optim import Momentum
|
|||
from mindspore.common.api import jit
|
||||
from mindspore.common import Parameter, ParameterTuple
|
||||
import mindspore.context as context
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
@jit
|
||||
def ConvBnReLU(x):
|
||||
conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
|
||||
|
@ -36,6 +38,7 @@ def ConvBnReLU(x):
|
|||
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -262,6 +265,7 @@ def test_pynative_ms_function():
|
|||
assert np.allclose(out_a[0][0].asnumpy(), out_b[0][0].asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(out_a[1][0].asnumpy(), out_b[1][0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -290,6 +294,7 @@ def test_pynative_ms_function_mix_execute():
|
|||
output = net(a, b)
|
||||
assert output == 8
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -324,6 +329,7 @@ def test_pynative_ms_function_empty_graph():
|
|||
output = net()
|
||||
assert output.asnumpy() == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -363,6 +369,7 @@ def test_pynative_ms_function_control_flow_if_break():
|
|||
output = net(x, y, z)
|
||||
assert (output.asnumpy() == z.asnumpy() * 4).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -410,7 +417,6 @@ def test_pynative_ms_function_with_tuple_inputs():
|
|||
new_grads.append(grad + 1)
|
||||
return new_grads
|
||||
|
||||
|
||||
x = Tensor(np.ones([2, 2]), dtype=ms.int32)
|
||||
y = Tensor(np.ones([2, 2]), dtype=ms.int32)
|
||||
net = Net()
|
||||
|
@ -480,7 +486,6 @@ def test_pynative_ms_function_with_kwargs_inputs():
|
|||
def foo(x, **kwargs):
|
||||
return x + kwargs.get('y')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
x = Tensor(3, dtype=ms.int32)
|
||||
data = {"y": 1}
|
||||
foo(x, **data)
|
||||
x = Tensor(3, dtype=ms.int32)
|
||||
data = {"y": 1}
|
||||
assert foo(x, **data).asnumpy() == [4]
|
||||
|
|
|
@ -83,20 +83,20 @@ def test_ms_function_tensor_compile_phase1():
|
|||
ms_create_time = int(time.time() * 1e9)
|
||||
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
|
||||
# The ms_function makes the tensor inputs mutable by default
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
phase1 = _ms_function_executor.compile("fn", x, y)
|
||||
phase2 = _ms_function_executor.compile("fn", p, q)
|
||||
assert phase1 != phase2
|
||||
# mutable api
|
||||
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
|
||||
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
|
||||
phase1 = _ms_function_executor.compile("fn", mutable(x), mutable(y))
|
||||
phase2 = _ms_function_executor.compile("fn", mutable(p), mutable(q))
|
||||
assert phase1 == phase2
|
||||
# set_mutable api of Tensor
|
||||
x.set_const_arg(False)
|
||||
y.set_const_arg(False)
|
||||
p.set_const_arg(False)
|
||||
q.set_const_arg(False)
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
phase1 = _ms_function_executor.compile("fn", x, y)
|
||||
phase2 = _ms_function_executor.compile("fn", p, q)
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
|
@ -157,20 +157,20 @@ def test_ms_function_tensor_compile_phase2():
|
|||
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
|
||||
ms_create_time = int(time.time() * 1e9)
|
||||
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
phase1 = _ms_function_executor.compile("fn", x, y)
|
||||
phase2 = _ms_function_executor.compile("fn", p, q)
|
||||
assert phase1 == phase2
|
||||
# Set const arg.
|
||||
x.set_const_arg()
|
||||
y.set_const_arg()
|
||||
p.set_const_arg()
|
||||
q.set_const_arg()
|
||||
phase1 = _ms_function_executor.compile((x, y), "fn")
|
||||
phase2 = _ms_function_executor.compile((p, q), "fn")
|
||||
phase1 = _ms_function_executor.compile("fn", x, y)
|
||||
phase2 = _ms_function_executor.compile("fn", p, q)
|
||||
assert phase1 != phase2
|
||||
# mutable api
|
||||
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
|
||||
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
|
||||
phase1 = _ms_function_executor.compile("fn", mutable(x), mutable(y))
|
||||
phase2 = _ms_function_executor.compile("fn", mutable(p), mutable(q))
|
||||
assert phase1 == phase2
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -12,11 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.ops.composite as C
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor, Parameter, jit
|
||||
from mindspore import nn
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -323,6 +324,7 @@ def test_args_kwarg_not_used():
|
|||
Description: Function with unused parameters which are varargs and kwargs.
|
||||
Expectation: compile success and result == 0
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, **kwargs):
|
||||
return 0
|
||||
|
@ -344,6 +346,7 @@ def test_args_kwonlyargs_1_kwarg_not_used():
|
|||
Description: Function with unused parameters which are varargs, 1 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == 0
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return 0
|
||||
|
@ -365,6 +368,7 @@ def test_args_kwonlyargs_2_kwarg_not_used():
|
|||
Description: Function with unused parameters which are varargs, 2 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == 0
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return 0
|
||||
|
@ -386,6 +390,7 @@ def test_args_1_used_kwonlyargs_kwarg_not_used():
|
|||
Description: Function with unused parameters which are 1 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == x
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return args[0]
|
||||
|
@ -407,6 +412,7 @@ def test_args_2_used_kwonlyargs_kwarg_not_used():
|
|||
Description: Function with unused parameters which are 1 kwonlyargs and kwargs.
|
||||
Expectation: compile success and result == y
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return args[1]
|
||||
|
@ -428,6 +434,7 @@ def test_kwonlyargs_1_used_args_kwarg_not_used():
|
|||
Description: Function with unused parameters which are varargs and kwargs.
|
||||
Expectation: compile success and result == only1
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, **kwargs):
|
||||
return only1
|
||||
|
@ -449,6 +456,7 @@ def test_kwonlyargs_2_used_args_kwarg_not_used():
|
|||
Description: Function with unused parameters which are varargs and kwargs.
|
||||
Expectation: compile success and result == only2
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return only2
|
||||
|
@ -470,6 +478,7 @@ def test_kwarg_used_args_kwonlyargs_not_used():
|
|||
Description: Function with unused parameters which are varargs and kwonlyargs.
|
||||
Expectation: compile success and result == kw1
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return kwargs["kw1"]
|
||||
|
@ -490,6 +499,7 @@ def test_args_1_kwonlyargs_1_used_kwarg_not_used():
|
|||
Description: Function with unused parameters which are kwargs.
|
||||
Expectation: compile success and result == (x, 3)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], only1)
|
||||
|
@ -510,6 +520,7 @@ def test_args_2_kwonlyargs_1_used_kwarg_not_used():
|
|||
Description: Function with unused parameters which are kwargs.
|
||||
Expectation: compile success and result == (x, y, 3)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], only1)
|
||||
|
@ -530,6 +541,7 @@ def test_args_2_kwonlyargs_2_used_kwarg_not_used():
|
|||
Description: Function with unused parameters which are kwargs.
|
||||
Expectation: compile success and result == (x, y, only1, only2)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], only1, only2)
|
||||
|
@ -550,6 +562,7 @@ def test_kwonlyargs_1_kwarg_used_args_not_used():
|
|||
Description: Function with unused parameters which are varargs.
|
||||
Expectation: compile success and result == (y, kw1)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (only1, kwargs["kw1"])
|
||||
|
@ -570,6 +583,7 @@ def test_kwonlyargs_2_kwarg_used_args_not_used():
|
|||
Description: Function with unused parameters which are varargs.
|
||||
Expectation: compile success and result == (only1, only2, kw1)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (only1, only2, kwargs["kw1"])
|
||||
|
@ -590,6 +604,7 @@ def test_args_1_kwarg_used_kwonlyargs_not_used():
|
|||
Description: Function with unused parameters which are kwonlyargs.
|
||||
Expectation: compile success and result == (x, kw1)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], kwargs["kw1"])
|
||||
|
@ -610,6 +625,7 @@ def test_args_2_kwarg_used_kwonlyargs_not_used():
|
|||
Description: Function with unused parameters which are kwonlyargs.
|
||||
Expectation: compile success and result == (x, y, kw1)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], kwargs["kw1"])
|
||||
|
@ -630,6 +646,7 @@ def test_args_1_kwonlyargs_1_kwarg_used():
|
|||
Description: Function with unused parameters which is kwonlyarg.
|
||||
Expectation: compile success and result == (x, only1, kw1)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], only1, kwargs["kw1"])
|
||||
|
@ -650,6 +667,7 @@ def test_args_2_kwonlyargs_2_kwarg_used():
|
|||
Description: Function without unused parameters.
|
||||
Expectation: compile success and result == (x, y, only1, only2, kw1)
|
||||
"""
|
||||
|
||||
class Net(Cell):
|
||||
def trivial(self, *args, only1=3, only2=4, **kwargs):
|
||||
return (args[0], args[1], only1, only2, kwargs["kw1"])
|
||||
|
@ -662,3 +680,305 @@ def test_args_2_kwonlyargs_2_kwarg_used():
|
|||
x = 1
|
||||
y = 2
|
||||
assert net(x, y) == (x, y, 3, 4, 5)
|
||||
|
||||
|
||||
def test_cell_keyword_argument():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Only positional arguments.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, b):
|
||||
return a * b
|
||||
|
||||
net = Net()
|
||||
assert net(2, b=3) == 6
|
||||
assert net(a=2, b=3) == 6
|
||||
assert net(b=3, a=2) == 6
|
||||
|
||||
|
||||
def test_cell_default_argument():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Positional arguments with default values.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, y=3, z=4):
|
||||
return x ** y + z
|
||||
|
||||
net = Net()
|
||||
assert net(2) == 12
|
||||
assert net(2, 1) == 6
|
||||
assert net(2, 3, 2) == 10
|
||||
assert net(y=1, z=3, x=2) == 5
|
||||
|
||||
|
||||
def test_cell_args1():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Only varargs.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, *args):
|
||||
x = 0
|
||||
for arg in args:
|
||||
x = x + arg
|
||||
return x
|
||||
|
||||
net = Net()
|
||||
assert net(1, 2, 3) == 6
|
||||
|
||||
|
||||
def test_cell_args2():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Positional arguments and varargs.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, *args):
|
||||
for arg in args:
|
||||
x = x + arg
|
||||
return x
|
||||
|
||||
net = Net()
|
||||
assert net(1, 2, 3) == 6
|
||||
|
||||
|
||||
def test_cell_kwargs1():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Only kwarg.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, **kwargs):
|
||||
return kwargs["a"] + kwargs["b"]
|
||||
|
||||
net = Net()
|
||||
assert net(a=1, b=2, c=3) == 3
|
||||
|
||||
|
||||
def test_cell_kwargs2():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Positional arguments and kwarg.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x, **kwargs):
|
||||
return x + kwargs["a"] - kwargs["b"]
|
||||
|
||||
net = Net()
|
||||
assert net(1, a=2, b=3) == 0
|
||||
|
||||
|
||||
def test_cell_args_kwargs():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Vararg and kwarg.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, *args, **kwargs):
|
||||
x = args[0] + args[1] - kwargs["c"] + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(1, 2, c=3, d=4) == 4
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='kwonly not support')
|
||||
def test_cell_kwonly1():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Only kwonly arguments.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, *, a, b):
|
||||
x = a + b
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(a=1, b=2) == 3
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='kwonly not support')
|
||||
def test_cell_kwonly2():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Positional args and kwonly arguments with default values.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, *, b, c=3):
|
||||
x = a + b - c
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(1, b=2, c=3) == 0
|
||||
assert net(1, b=2) == 0
|
||||
|
||||
|
||||
def test_cell_mixed_arguments1():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, b, c=3, *args, **kwargs):
|
||||
x = a + b - c + args[0] - args[1] + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(1, 2, 3, 4, 5, d=6, e=7) == 5
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='kwonly not support')
|
||||
def test_cell_mixed_arguments2():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, *args, b, c=1, **kwargs):
|
||||
x = a + args[0] - args[1] + b - c + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(1, 2, 3, b=4, c=5, d=6) == 5
|
||||
assert net(1, 2, 3, b=4, d=6) == 9
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='kwonly not support')
|
||||
def test_cell_mixed_arguments3():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, *, b, c=1, **kwargs):
|
||||
x = a + b - c + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(1, b=2, c=3, d=4) == 4
|
||||
assert net(1, b=4, d=6) == 10
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='kwonly not support')
|
||||
def test_cell_mixed_arguments_with_dict_input():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments with dictionary.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, a, *args, b, c=1, **kwargs):
|
||||
x = a["item0"] + args[0] + args[1]["item1"] + b["item2"] + c + kwargs["d"]["item3"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net({"item0": 1}, 2, {"item1": 3}, b={"item2": 4}, c=5, d={"item3": 6}) == 21
|
||||
|
||||
|
||||
def test_cell_mixed_arguments_with_sub_cell():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Mixed arguments with sub cell.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class SubNet(nn.Cell):
|
||||
def construct(self, a, *args, b, c=1, **kwargs):
|
||||
x = a + args[0] + args[1] + b + c + kwargs["d"]
|
||||
return x
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.subnet = SubNet()
|
||||
|
||||
def construct(self, a, arg0, arg1, b, c, d):
|
||||
x = self.subnet(a, arg0, arg1, b=b, c=c, d=d)
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
assert net(1, 2, 3, 4, 5, 6) == 21
|
||||
|
||||
|
||||
def test_jit_kwargs():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Vararg and kwarg for jit function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def func(*args, **kwargs):
|
||||
x = args[0] + args[1] - kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
assert func(1, 2, c=3, d=4) == -1
|
||||
assert func(1, 2, d=4) == -1
|
||||
assert func(1, 2, d=4, e=5) == -1
|
||||
assert func(1, 2, 2.1, 2.2, d=4, e=5, f=6) == -1
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
assert func(1, 2, c=3, d=4) == -1
|
||||
assert func(1, 2, d=4) == -1
|
||||
assert func(1, 2, d=4, e=5) == -1
|
||||
assert func(1, 2, 2.1, 2.2, d=4, e=5, f=6) == -1
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='kwonly not support')
|
||||
def test_jit_mixed_arguments():
|
||||
"""
|
||||
Feature: Support kwargs for top graph.
|
||||
Description: Vararg and kwarg for jit function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def func(a, *args, b, c=5, **kwargs):
|
||||
x = a + args[0] - args[1] + b - c + kwargs["d"]
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
assert func(1, 2, 3, b=4, c=5, d=6) == 5
|
||||
assert func(1, 2, 3, b=4, d=6) == 5
|
||||
assert func(1, 2, 3, 4, b=5, c=6, d=7) == 6
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
assert func(1, 2, 3, b=4, c=5, d=6) == 5
|
||||
assert func(1, 2, 3, b=4, d=6) == 5
|
||||
assert func(1, 2, 3, 4, b=5, c=6, d=7) == 6
|
||||
|
|
Loading…
Reference in New Issue