!46993 support kwargs for top graph

Merge pull request !46993 from huanghui/support-kwargs
This commit is contained in:
i-robot 2023-02-17 01:30:30 +00:00 committed by Gitee
commit ce38106313
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 658 additions and 180 deletions

View File

@ -86,8 +86,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_l
return res_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key_value), dict_get_item});
});
} else {
MS_LOG(EXCEPTION) << "The arguments of UnpackCall operator should be tuple, list or dict, but got "
<< args_abs_list[index]->ToString();
// For the mixed arguments: func(a, *args)
AnfNodePtr param = res_graph->add_parameter();
(void)elems.emplace_back(param);
}
}
// Add to order list to trace if fn_node had side effect.

View File

@ -157,8 +157,8 @@ PYBIND11_MODULE(_c_expression, m) {
py::arg("incremental") = py::bool_(false), py::arg("obf_ratio") = py::float_(1.0),
py::arg("branch_control_input") = py::int_(0), "Get graph proto of dynamic-obfuscated model.")
.def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph")
.def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
.def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("kwargs"),
py::arg("phase") = py::str(""), py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
.def("updata_param_node_default_input", &GraphExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
.def("get_parameter_layout", &GraphExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),

View File

@ -652,8 +652,8 @@ FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const Function
// Save the function node to block
func_block->WriteVariable(function_name, NewValueNode(current_fg));
py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
(void)ParseStatements(func_block, funcObj);
py::object func_obj = python_adapter::GetPyObjAttr(node, "body");
(void)ParseStatements(func_block, func_obj);
// Add unused variables as isolate nodes.
for (auto &func_block_item : func_block_list_) {
@ -1337,15 +1337,14 @@ void Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object
values.push_back(ret_node);
}
}
auto keys_tuple = GenerateMakeTuple(block, keys);
auto values_tuple = GenerateMakeTuple(block, values);
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
std::vector<AnfNodePtr> make_dict_nodes;
make_dict_nodes.push_back(make_dict_op);
make_dict_nodes.push_back(keys_tuple);
make_dict_nodes.push_back(values_tuple);
MS_EXCEPTION_IF_NULL(block->func_graph());
args_context->packed_arguments.push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
if (!keys.empty()) {
auto keys_tuple = GenerateMakeTuple(block, keys);
auto values_tuple = GenerateMakeTuple(block, values);
auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
std::vector<AnfNodePtr> make_dict_nodes = {make_dict_op, keys_tuple, values_tuple};
MS_EXCEPTION_IF_NULL(block->func_graph());
args_context->packed_arguments.push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
}
}
}
@ -3562,7 +3561,7 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
param->set_name(name);
MS_EXCEPTION_IF_NULL(param->debug_info());
param->debug_info()->set_name(name);
param->debug_info()->set_location(param->debug_info()->location());
param->debug_info()->set_location(orig_param->debug_info()->location());
param->set_is_top_graph_param(true);
}
func_graph->set_has_vararg(current_graph->has_vararg());

View File

@ -76,6 +76,7 @@
#include "include/backend/data_queue/data_queue_mgr.h"
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#include "abstract/abstract_value.h"
#endif
#if defined(__linux__) && defined(WITH_BACKEND)
#include "ps/constants.h"
@ -403,19 +404,18 @@ void CheckArgsValid(const py::object &source_obj, const py::tuple &args) {
}
}
py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py::tuple &args,
py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
bool enable_tuple_broaden) {
MS_LOG(DEBUG) << "GenerateArgumentsKey args size: " << args.size()
<< ", enable_tuple_broaden: " << enable_tuple_broaden;
abstract::AbstractBasePtrList args_abs;
cur_convert_input_.clear();
std::size_t size = args.size();
for (std::size_t i = 0; i < size; i++) {
ClearCurConvertInput();
for (std::size_t i = 0; i < args.size(); i++) {
ValuePtr converted = nullptr;
if (!parse::ConvertData(args[i], &converted)) {
MS_EXCEPTION(TypeError) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
MS_LOG(EXCEPTION) << "parse::ConvertData for " << i << "th argument failed, the argument type is "
<< args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
}
AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
(void)args_abs.emplace_back(abs);
@ -423,6 +423,20 @@ py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py
// so we keep all inputs for subsequent procedure.
(void)cur_convert_input_.emplace(args[i].ptr(), std::make_pair(converted, abs));
}
for (const auto &item : kwargs) {
ValuePtr key = nullptr;
ValuePtr value = nullptr;
bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
parse::ConvertData(py::cast<py::object>(item.second), &value);
if (!success) {
MS_LOG(EXCEPTION) << "parse::ConvertData for argument (" << py::str(item.first) << ": " << py::str(item.second)
<< ") failed.";
}
AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden);
auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
(void)args_abs.emplace_back(keyword_arg_abs);
(void)cur_convert_input_.emplace(item.first.ptr(), std::make_pair(value, keyword_arg_abs));
}
// If cache matched no need CheckArgsValid
auto iter = kArgsCache.find(args_abs);
@ -867,8 +881,8 @@ void GraphExecutorPy::CleanCompileRes(const ResourcePtr &resource) {
MS_LOG(INFO) << "Clean compile resource end";
}
bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj,
bool use_vm) {
bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
const py::object &phase_obj, bool use_vm) {
// Check if the phase is valid.
if ((!py::isinstance<py::str>(phase_obj))) {
MS_LOG(ERROR) << "The `phase` must be string.";
@ -887,7 +901,8 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
phase_ = phase;
auto obj_desc = GetObjDesc(source_obj);
MS_LOG(INFO) << "Start compiling, phase: " << phase;
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args));
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args))
<< "\nkwargs: " << py::str(const_cast<py::dict &>(kwargs));
EventMessage::PrintCompileStartMsg(phase, obj_desc);
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
@ -914,42 +929,15 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
// Get the parameters items and add the value to args_abs.
abstract::AbstractBasePtrList args_abs;
std::vector<ValuePtr> arguments;
std::size_t size = args.size();
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
bool is_parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel;
bool is_auto_parallel = is_parallel_mode && !py::hasattr(source_obj, parallel::kSkipAutoParallelCompile) &&
!py::hasattr(source_obj, parallel::kKeepInputUnchanged);
for (std::size_t i = 0; i < size; i++) {
ValuePtr converted = nullptr;
// In some parallel mode need full_tensor which cause the args of GenerateArgumentsKey not same to compile,
// So can't use cur_convert_input_ directly.
auto iter = cur_convert_input_.find(args[i].ptr());
if (iter != cur_convert_input_.end()) {
(void)arguments.emplace_back(iter->second.first);
if (is_auto_parallel) {
auto abs_item = iter->second.second->Clone();
(void)parallel::ExtendInputArgsAbstractShape(abs_item, i);
(void)args_abs.emplace_back(abs_item);
continue;
}
(void)args_abs.emplace_back(iter->second.second);
continue;
}
bool succ = parse::ConvertData(args[i], &converted);
if (!succ) {
MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
}
(void)arguments.emplace_back(converted);
auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
if (is_auto_parallel) {
(void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
}
(void)args_abs.emplace_back(args_abstract_item);
}
ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
resource->set_arguments(arguments);
resource->set_args_abs(args_abs);
executor_info->arg_list_size = size;
executor_info->arg_list_size = args.size() + kwargs.size();
executor_info->resource = resource;
info_[phase] = executor_info;
pip->Run();
@ -969,6 +957,59 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
return true;
}
// Convert the Python positional `args` and keyword `kwargs` into ValuePtrs and
// abstract values for compilation. Converted values are appended to
// `arguments` and their abstracts to `args_abs` — positional arguments first,
// then keyword arguments — matching the order produced by
// GenerateArgumentsKey. When `is_auto_parallel` is true, each positional
// abstract's shape is extended via parallel::ExtendInputArgsAbstractShape.
void GraphExecutorPy::ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
                                  abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments) {
  MS_EXCEPTION_IF_NULL(args_abs);
  MS_EXCEPTION_IF_NULL(arguments);
  for (std::size_t i = 0; i < args.size(); i++) {
    // In some parallel mode need full_tensor which cause the args of GenerateArgumentsKey not same to compile,
    // So can't use cur_convert_input_ directly.
    auto iter = cur_convert_input_.find(args[i].ptr());
    if (iter != cur_convert_input_.end()) {
      // Cache hit: reuse the (value, abstract) pair cached by GenerateArgumentsKey.
      (void)arguments->emplace_back(iter->second.first);
      if (is_auto_parallel) {
        // Clone before extending the shape so the cached abstract stays untouched.
        auto abs_item = iter->second.second->Clone();
        (void)parallel::ExtendInputArgsAbstractShape(abs_item, i);
        (void)args_abs->emplace_back(abs_item);
        continue;
      }
      (void)args_abs->emplace_back(iter->second.second);
      continue;
    }
    // Cache miss: convert the Python object from scratch.
    ValuePtr converted = nullptr;
    bool success = parse::ConvertData(args[i], &converted);
    if (!success) {
      MS_LOG(EXCEPTION) << "Fail to convert the " << i << "th argument, args[" << i << "]: " << py::str(args[i]);
    }
    (void)arguments->emplace_back(converted);
    auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
    if (is_auto_parallel) {
      (void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
    }
    (void)args_abs->emplace_back(args_abstract_item);
  }
  for (const auto &item : kwargs) {
    // Keyword entries are cached under the Python key object's pointer identity.
    auto iter = cur_convert_input_.find(item.first.ptr());
    if (iter != cur_convert_input_.end()) {
      (void)arguments->emplace_back(iter->second.first);
      (void)args_abs->emplace_back(iter->second.second);
      continue;
    }
    // Convert both the keyword name (expected to convert to a string value,
    // see GetValue<std::string> below) and the keyword value.
    ValuePtr key = nullptr;
    ValuePtr value = nullptr;
    bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
                   parse::ConvertData(py::cast<py::object>(item.second), &value);
    if (!success) {
      MS_LOG(EXCEPTION) << "Fail to convert the argument (" << py::str(item.first) << ": " << py::str(item.second)
                        << ").";
    }
    // Wrap the value's abstract in an AbstractKeywordArg that carries the keyword name.
    AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden_);
    auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
    (void)arguments->emplace_back(value);
    (void)args_abs->emplace_back(keyword_arg_abs);
  }
}
std::vector<ActionItem> GraphExecutorPy::FilterActions(const std::vector<ActionItem> &actions,
const std::string &phase) {
// filter action after validate when 'export'.
@ -1002,12 +1043,12 @@ void GraphExecutorPy::ReleaseResource(const py::object &phase) {
}
}
bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase,
bool use_vm) {
bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
const py::object &phase, bool use_vm) {
bool ret_value = false;
try {
ProcessStatus::GetInstance().RecordStart("CompileInner");
ret_value = CompileInner(source_obj, args, phase, use_vm);
ret_value = CompileInner(source_obj, args, kwargs, phase, use_vm);
ProcessStatus::GetInstance().RecordEnd();
ProcessStatus::GetInstance().Print();
} catch (const py::error_already_set &ex) {
@ -1277,9 +1318,8 @@ bool Pipeline::NeedCreateBackend() {
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
MS_EXCEPTION_IF_NULL(arg_list);
std::size_t size = args.size();
bool arg_list_inited = !arg_list->empty();
for (std::size_t i = 0; i < size; i++) {
for (std::size_t i = 0; i < args.size(); i++) {
py::object arg = args[i];
ValuePtr converted = nullptr;
bool succ = parse::ConvertData(arg, &converted);

View File

@ -77,8 +77,12 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
const std::string &phase() const { return phase_; }
const std::map<std::string, std::string> &jit_config() const { return jit_config_; }
void SaveCompiledGraph(const std::string &phase);
bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj, bool use_vm);
bool Compile(const py::object &source_obj, const py::tuple &args, const py::object &phase, bool use_vm);
void ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments);
bool CompileInner(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs,
const py::object &phase_obj, bool use_vm);
bool Compile(const py::object &source_obj, const py::tuple &args, const py::dict &kwargs, const py::object &phase,
bool use_vm);
void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list);
@ -131,7 +135,8 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
#endif
// Generate a key for mapping function graph
py::object GenerateArgumentsKey(const py::object &obj, const py::tuple &args, bool enable_tuple_broaden = false);
py::object GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
bool enable_tuple_broaden = false);
void ClearCurConvertInput();

View File

@ -209,10 +209,11 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
return *key_value == *item.first->BuildValue();
});
if (it == dict_elems.end()) {
// For dict[key], if key is not exist, will raise a KeyError exception.
// For dict[key], if key is not exist, will raise a ValueError exception.
// For dict.get('key', default=None), if key is not exist, will return the default value during dict_get.
MS_EXCEPTION(KeyError) << "The key " << key_value->ToString()
<< " does not exist in the dict:" << args_spec_list[0]->BuildValue()->ToString();
// Python KeyError will print escape character. So use ValueError instead of KeyError here.
MS_EXCEPTION(ValueError) << "The key " << key_value->ToString()
<< " does not exist in the dict:" << args_spec_list[0]->BuildValue()->ToString();
}
return it->second;
}

View File

@ -38,7 +38,7 @@ std::string Keyword::ToString() const {
MS_EXCEPTION_IF_NULL(value_);
buffer << "Keyword[";
buffer << "key : " << key_;
buffer << "value : " << value_->ToString();
buffer << ", value : " << value_->ToString();
buffer << "]";
}
return buffer.str();

View File

@ -303,7 +303,7 @@ class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
void GenerateKwParams(const FuncGraphPtr &specialized_graph,
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, int pos_args_input_count,
std::vector<AnfNodePtr> *specialized_parameter_list,
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;

View File

@ -122,25 +122,30 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, int var
void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
std::vector<AnfNodePtr> *specialized_parameter_list,
int pos_args_input_count, std::vector<AnfNodePtr> *specialized_parameter_list,
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
MS_EXCEPTION_IF_NULL(repl_nodes);
MS_EXCEPTION_IF_NULL(specialized_graph);
std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)};
std::set<AnfNodePtr> kwarg_nodes;
for (const auto &kwarg : kwarg_list) {
for (size_t i = 0; i < kwarg_list.size(); ++i) {
auto kwarg = kwarg_list[i];
MS_EXCEPTION_IF_NULL(kwarg);
std::string kw_param_name = kwarg->get_key();
MS_EXCEPTION_IF_NULL(specialized_graph);
AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name);
// If not find corresponding parameter node.
if (param_node == nullptr) {
if (!has_kwarg()) {
MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
if (pos_args_input_count + i > specialized_graph->parameters().size() - 1) {
MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name;
}
specialized_parameter_list->push_back(specialized_graph->parameters()[pos_args_input_count + i]);
} else {
ParameterPtr para = std::make_shared<Parameter>(specialized_graph);
std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]";
MS_EXCEPTION_IF_NULL(specialized_parameter_list);
auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(),
[param_name](const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
@ -169,7 +174,6 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph,
auto extract_node = specialized_graph->NewCNode(
{NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node});
kwarg_nodes.insert(param_node);
MS_EXCEPTION_IF_NULL(repl_nodes);
(void)repl_nodes->emplace(param_node, extract_node);
}
}
@ -182,7 +186,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes,
mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const {
if (has_kwarg()) {
if (has_kwarg() && !kwarg_keys_tuple_nodes.empty()) {
MS_EXCEPTION_IF_NULL(specialized_graph);
TraceGuard guard(
std::make_shared<TraceGenerateKwArg>(specialized_graph->GetVariableKwargParameter()->debug_info()));
@ -264,7 +268,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl_nodes;
GenerateVarParams(specialized_graph, variable_args_count, pos_args_input_count, &specialized_parameter_list,
&repl_nodes);
GenerateKwParams(specialized_graph, kwarg_list, &specialized_parameter_list, &repl_nodes);
GenerateKwParams(specialized_graph, kwarg_list, pos_args_input_count, &specialized_parameter_list, &repl_nodes);
GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes);

View File

@ -120,28 +120,30 @@ def _handle_func_args(func, *args, **kwargs):
bound_arguments.apply_defaults()
args = bound_arguments.args
kwargs = bound_arguments.kwargs
# After apply_defaults, kwargs should be empty here.
if kwargs:
raise ValueError(f"Failed to handle kwargs of {func.__name__}. Maybe you pass wrong arguments, "
f"or there is a key in kwargs that is not used as a function argument, "
f"args: {args}, kwargs: {kwargs}")
positional_args = 0
default_args = 0
has_var = False
for value in inspect.signature(func).parameters.values():
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
return args
has_var = True
if value.kind is inspect.Parameter.KEYWORD_ONLY:
raise TypeError(f"Function {func.__name__}, MindSpore does not support keyword-only arg: {value}.")
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
if value.default is inspect.Parameter.empty:
positional_args += 1
else:
default_args += 1
if has_var:
return args, kwargs
if len(args) < positional_args:
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
if len(args) > positional_args + default_args:
raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
f"default argument, total {positional_args + default_args}, but got {len(args)}.")
return args
return args, kwargs
sys_path = list(sys.path)
@ -240,25 +242,39 @@ def _get_parameter_layout():
return layout
def _get_args_for_run(obj, args_list):
    """Filter ``args_list`` down to the inputs actually fed to the runtime.

    Kept arguments are: non-constant tensors (lazily-initialized tensor data
    is materialized first), sparse tensors, objects flagged mutable via
    ``__ms_mutable__``, scalars when ``grad_for_scalar`` is enabled, and
    all-tensor tuples when ``obj`` enables tuple broadening.
    """
    runtime_inputs = []
    for arg in args_list:
        keep = False
        if isinstance(arg, PythonTensor):
            # Materialize lazily-initialized tensor data before running.
            if arg.has_init:
                arg.init_data()
            keep = not arg.const_arg
        elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
            keep = True
        elif getattr(arg, "__ms_mutable__", False):
            keep = True
        elif context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
            keep = True
        elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and \
                isinstance(arg, tuple) and _check_all_tensor(arg):
            keep = True
        if keep:
            runtime_inputs.append(arg)
    return runtime_inputs
def _handle_arg(obj, arg):
    """Return ``arg`` if it should be passed to the runtime, else ``None``.

    Accepted arguments are: non-constant tensors (lazily-initialized tensor
    data is materialized first), sparse tensors, objects flagged mutable via
    ``__ms_mutable__``, scalars when ``grad_for_scalar`` is enabled, and
    all-tensor tuples when ``obj`` enables tuple broadening.
    """
    if isinstance(arg, PythonTensor):
        # Materialize lazily-initialized tensor data before running.
        if arg.has_init:
            arg.init_data()
        return None if arg.const_arg else arg
    if isinstance(arg, (Tensor, CSRTensor, COOTensor)):
        return arg
    if getattr(arg, "__ms_mutable__", False):
        return arg
    if context.get_context("grad_for_scalar") and isinstance(arg, (int, float)):
        return arg
    if hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and \
            isinstance(arg, tuple) and _check_all_tensor(arg):
        return arg
    return None
def _get_args_for_run(obj, args, kwargs):
    """Get the actual input args and kwargs for runtime.

    Args:
        obj: The function or cell object being executed; consulted by
            ``_handle_arg`` (e.g. for ``enable_tuple_broaden``).
        args (tuple): Positional arguments.
        kwargs (dict): Keyword arguments; only the values are forwarded,
            in the dict's insertion order, after all positional arguments.

    Returns:
        list, the arguments accepted by ``_handle_arg`` (constant tensors
        and other non-runtime values are dropped), positional first.
    """
    new_args = []
    # Positional values first, then keyword values, to match compile order.
    for candidate in (*args, *kwargs.values()):
        handled = _handle_arg(obj, candidate)
        if handled is not None:
            new_args.append(handled)
    return new_args
class _MindsporeFunctionExecutor:
@ -296,7 +312,7 @@ class _MindsporeFunctionExecutor:
self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
@_wrap_func
def __call__(self, *args):
def __call__(self, *args, **kwargs):
args_list = args
if self.obj is not None:
args_list = args_list[1:]
@ -304,10 +320,10 @@ class _MindsporeFunctionExecutor:
phase = ""
if context.get_context("mode") == context.PYNATIVE_MODE:
_pynative_executor.set_ms_function_compile_status(True, phase)
phase = self.compile(args_list, self.fn.__name__)
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
_pynative_executor.set_ms_function_compile_status(False, phase)
else:
phase = self.compile(args_list, self.fn.__name__)
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
except Exception as err:
_pynative_executor.clear_res()
raise err
@ -315,7 +331,7 @@ class _MindsporeFunctionExecutor:
if context.get_context("precompile_only"):
return None
new_inputs = self._generate_run_args(args_list)
new_inputs = self._generate_run_args(args_list, kwargs)
output = self._graph_executor(tuple(new_inputs), phase)
if context.get_context("mode") == context.PYNATIVE_MODE:
output = _pynative_executor.grad_ms_function(output, *new_inputs)
@ -331,7 +347,7 @@ class _MindsporeFunctionExecutor:
return output
def compile(self, args_list, method_name):
def compile(self, method_name, *args, **kwargs):
"""Returns pipeline for the given args."""
# Check whether hook function registered on Cell object.
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
@ -340,9 +356,9 @@ class _MindsporeFunctionExecutor:
f"If you want to use hook function, please use context.set_context to set "
f"pynative mode and remove 'jit' decorator.")
# Chose dynamic shape tensors or actual input tensors as compile args.
compile_args = self._generate_compile_args(args_list)
compile_args = self._generate_compile_args(args)
# Restore the mutable attr for every arg.
compile_args = _restore_mutable_attr(args_list, compile_args)
compile_args = _restore_mutable_attr(args, compile_args)
generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
str(self.fn.__code__.co_firstlineno)
@ -371,7 +387,7 @@ class _MindsporeFunctionExecutor:
self.enable_tuple_broaden = self.obj.enable_tuple_broaden
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
key = self._graph_executor.generate_arguments_key(self.fn, compile_args, self.enable_tuple_broaden)
key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
phase = generate_name + '.' + str(key)
if phase in ms_compile_cache:
return phase
@ -382,11 +398,11 @@ class _MindsporeFunctionExecutor:
self._graph_executor.set_jit_config(self.jit_config_dict)
if self.obj is None:
is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
else:
if isinstance(self.obj, ms.nn.Cell):
self._graph_executor.set_weights_values(self.obj.parameters_dict())
is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase, True)
if not is_compile:
raise RuntimeError("Executor compile failed.")
@ -449,17 +465,18 @@ class _MindsporeFunctionExecutor:
raise ValueError("The input args is incompatible with the args in `input_signature`!")
return compile_args
def _generate_run_args(self, args_list):
def _generate_run_args(self, args_list, kwargs):
"""
Generate input args, which are required for running.
Args:
args_list (Tuple): Actual input args.
kwargs (Dict): Actual input kwargs.
Returns:
new_inputs, new input args, which are required for running.
"""
return _get_args_for_run(self, args_list)
return _get_args_for_run(self, args_list, kwargs)
# The attributes used to identify a given object.
@ -577,14 +594,15 @@ def jit(fn=None, input_signature=None, hash_args=None, jit_config=None):
if os.getenv("MS_JIT") == '0':
return func(*args, **kwargs)
args = _handle_func_args(func, *args, **kwargs)
args, kwargs = _handle_func_args(func, *args, **kwargs)
process_obj = None
if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
process_obj = args[0]
# only the function or cell instance wrapped by shard will fall into this branch
if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME:
process_obj = hash_args
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args)
out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args, **kwargs)
return out
return staging_specialize
@ -1338,16 +1356,17 @@ class _CellGraphExecutor:
if "train" in phase and (enable_compile_cache is True or enable_compile_cache == "1"):
self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None):
def compile(self, obj, *args, phase='predict', do_convert=True, jit_config_dict=None, **kwargs):
"""
Compiles graph.
Args:
obj (Function/Cell): The function or cell instance need compile.
args (tuple): Function or cell input arguments.
phase (str): The name of compile phase. Default: 'predict'.
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
jit_config_dict (dict): Jit config for compile. Default: None.
args (tuple): Args of the Cell object.
kwargs (dict): Kwargs of the Cell object.
Return:
Str, the full phase of the cell.
@ -1357,14 +1376,13 @@ class _CellGraphExecutor:
if not hasattr(obj, obj.__parse_method__):
raise AttributeError(
'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
args_list = args
self.enable_tuple_broaden = False
if hasattr(obj, "enable_tuple_broaden"):
self.enable_tuple_broaden = obj.enable_tuple_broaden
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
key = self._graph_executor.generate_arguments_key(obj, args_list, self.enable_tuple_broaden)
key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
obj.arguments_key = str(key)
phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
@ -1374,14 +1392,14 @@ class _CellGraphExecutor:
obj.check_names()
_check_full_batch()
self._set_dataset_mode(args_list)
self._set_dataset_mode(args)
self._set_compile_cache_dep_files(phase)
enable_ge = context.get_context("enable_ge")
self._graph_executor.set_weights_values(obj.parameters_dict())
if jit_config_dict:
self._graph_executor.set_jit_config(jit_config_dict)
result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
obj.compile_cache.add(phase)
if not result:
raise RuntimeError("Executor compile failed.")
@ -1458,6 +1476,8 @@ class _CellGraphExecutor:
Run the specific graph.
Args:
obj (Cell): The cell object.
args (tuple): Args of the Cell object.
phase (str): The phase name. Default: 'predict'.
Returns:

View File

@ -443,36 +443,38 @@ class Cell(Cell_):
output = self._run_forward_hook(cast_inputs, output)
return output
def _check_construct_args(self, *inputs, **kwargs):
def _check_construct_args(self, *args):
"""Check the args needed by the function construct"""
if kwargs:
raise ValueError(f"For 'Cell', expect no kwargs here, maybe you pass wrong arguments, "
f"or there is a key in kwargs that is not used as a function argument. "
f"args: {inputs}, kwargs: {kwargs}")
positional_args = 0
default_args = 0
has_var = False
for value in inspect.signature(self.construct).parameters.values():
if value.kind is inspect.Parameter.VAR_POSITIONAL or value.kind is inspect.Parameter.VAR_KEYWORD:
return
has_var = True
if value.kind is inspect.Parameter.KEYWORD_ONLY:
raise TypeError(f"For the method 'construct', MindSpore does not support keyword-only arg: {value}.")
if value.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
if value.default is inspect.Parameter.empty:
positional_args += 1
else:
default_args += 1
if len(inputs) < positional_args:
if has_var:
return
if len(args) < positional_args:
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument, "
f"but got {len(inputs)}. When using set_inputs, please make sure that all networks "
f"but got {len(args)}. When using set_inputs, please make sure that all networks "
f"and loss functions are configured with set_inputs.")
if len(inputs) > positional_args + default_args:
if len(args) > positional_args + default_args:
construct_inputs_names = self.construct.__code__.co_varnames
if 'self' not in construct_inputs_names:
raise TypeError(f"For 'Cell', the method 'construct' must have parameter 'self'. ")
raise TypeError(f"For 'Cell', the function construct requires {positional_args} positional argument and "
f"{default_args} default argument, total {positional_args + default_args}, "
f"but got {len(inputs)}.")
f"but got {len(args)}.")
def _hook_fn_registered(self):
'''Hook function in graph mode'''
@ -615,7 +617,7 @@ class Cell(Cell_):
def __call__(self, *args, **kwargs):
if self.__class__.construct is Cell.construct:
raise AttributeError("For 'Cell', the method 'construct' is not defined. ")
raise AttributeError("For 'Cell', the method 'construct' is not defined.")
if kwargs:
bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
@ -625,11 +627,11 @@ class Cell(Cell_):
# Run in Graph mode.
if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
self._check_construct_args(*args, **kwargs)
self._check_construct_args(*args)
if self._hook_fn_registered():
logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
f"function, please use context.set_context to set pynative mode.")
out = self.compile_and_run(*args)
out = self.compile_and_run(*args, **kwargs)
return out
# Run in PyNative mode.
@ -918,24 +920,25 @@ class Cell(Cell_):
return self._dynamic_shape_inputs
def compile(self, *inputs):
def compile(self, *args, **kwargs):
"""
Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
Args:
inputs (tuple): Inputs of the Cell object.
args (tuple): Args of the Cell object.
kwargs (dict): Kwargs of the Cell object.
"""
if self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None:
_cell_graph_executor.compile(self, *inputs, phase=self.phase,
jit_config_dict=self._jit_config_dict)
_cell_graph_executor.compile(self, phase=self.phase,
jit_config_dict=self._jit_config_dict, *args, **kwargs)
else:
self._check_compile_dynamic_shape(*inputs)
self._check_compile_dynamic_shape(*args)
self.saved_dynamic_shape = self._dynamic_shape_inputs
_cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
jit_config_dict=self._jit_config_dict)
jit_config_dict=self._jit_config_dict, **kwargs)
logger.debug("Compiled Graph with dynamic shape")
def compile_and_run(self, *inputs):
def compile_and_run(self, *args, **kwargs):
"""
Compile and run Cell, the input must be consistent with the input defined in construct.
@ -943,15 +946,16 @@ class Cell(Cell_):
It is not recommended to call directly.
Args:
inputs (tuple): Inputs of the Cell object.
args (tuple): Args of the Cell object.
kwargs (dict): Kwargs of the Cell object.
Returns:
Object, the result of executing.
"""
self.compile(*inputs)
self.compile(*args, **kwargs)
new_inputs = _get_args_for_run(self, inputs)
return _cell_graph_executor(self, *new_inputs, phase=self.phase)
new_args = _get_args_for_run(self, args, kwargs)
return _cell_graph_executor(self, *new_args, phase=self.phase)
def auto_parallel_compile_and_run(self):
"""
@ -1046,7 +1050,7 @@ class Cell(Cell_):
f"but got type {type(child_cell)}.")
self._cells[child_name] = child_cell
def construct(self, *inputs, **kwargs):
def construct(self, *args, **kwargs):
"""
Defines the computation to be performed. This method must be overridden by all subclasses.
@ -1054,7 +1058,7 @@ class Cell(Cell_):
It is not supported currently that inputs contain both tuple and non-tuple types at same time.
Args:
inputs (tuple): Tuple of variable parameters.
args (tuple): Tuple of variable parameters.
kwargs (dict): Dictionary of variable keyword parameters.
Returns:
@ -2303,13 +2307,13 @@ class GraphCell(Cell):
def construct(self, *inputs):
return self.graph(*inputs)
def __call__(self, *inputs):
def __call__(self, *args, **kwargs):
self.phase = "graph_load_from_mindir"
self._add_attr("graph_load_from_mindir", self.graph)
if not self.obf_random_seed:
return self.compile_and_run(*inputs)
return self.compile_and_run(*args, **kwargs)
append_input = Tensor((numpy.ones((1, 1)) * self._branch_control_input).astype(numpy.int32))
return self.compile_and_run(*inputs, append_input)
return self.compile_and_run(*args, append_input, **kwargs)
def _check_param_list_tuple(value):

View File

@ -565,17 +565,17 @@ class _Grad(GradOperation_):
dynamic_shape_inputs = fn.get_inputs()
if self.get_by_position:
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn, weights, grad_position)(*args)
def after_grad(*args, **kwargs):
return grad_(fn, weights, grad_position)(*args, **kwargs)
else:
if self.get_by_list:
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn, weights)(*args)
def after_grad(*args, **kwargs):
return grad_(fn, weights)(*args, **kwargs)
else:
@jit(input_signature=dynamic_shape_inputs)
def after_grad(*args):
return grad_(fn)(*args)
def after_grad(*args, **kwargs):
return grad_(fn)(*args, **kwargs)
elif self.pynative_:
@_wrap_func
def after_grad(*args, **kwargs):
@ -663,8 +663,8 @@ class _Vmap(VmapOperation_):
vmap_ = self
@jit
def after_vmap(*args):
return vmap_(fn, in_axes, out_axes)(*args)
def after_vmap(*args, **kwargs):
return vmap_(fn, in_axes, out_axes)(*args, **kwargs)
self.vmap_fn = after_vmap
self.fn = fn

View File

@ -1455,7 +1455,7 @@ def _msfunc_info(net, *inputs):
# pylint: disable=protected-access
net_dict = OrderedDict()
_ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
graph_id = _ms_func_executor.compile(args_list=inputs, method_name=net.__name__)
graph_id = _ms_func_executor.compile(net.__name__, *inputs)
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
params = _ms_func_executor._graph_executor.get_params(graph_id)
for name, value in params.items():

View File

@ -17,7 +17,7 @@ import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import Tensor, Parameter, ParameterTuple, jit
from mindspore.ops import composite as C
from mindspore.ops import operations as P
import mindspore.ops as ops
@ -320,3 +320,82 @@ def test_grad_parameter_as_input_and_fv2(mode):
assert np.array_equal(a[0][0].asnumpy(), b[0][0].asnumpy())
assert np.array_equal(a[1][0].asnumpy(), b[1][0].asnumpy())
assert np.array_equal(a[1][1].asnumpy(), b[1][1].asnumpy())
# Module-level tensor inputs shared by the kwargs/varargs test cases below.
tensor1 = Tensor([1])
tensor2 = Tensor([2])
tensor3 = Tensor([3])
tensor4 = Tensor([4])
tensor5 = Tensor([5])
tensor6 = Tensor([6])
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cell_mixed_arguments():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct combining a leading positional arg with *args and **kwargs.
    Expectation: No exception.
    """
    class MixedNet(nn.Cell):
        def construct(self, a, *args, **kwargs):
            return a + args[0] + args[1] + kwargs["d"]

    context.set_context(mode=context.GRAPH_MODE)
    net = MixedNet()
    assert net(tensor1, tensor2, tensor3, b=tensor4, c=tensor5, d=tensor6).asnumpy() == [12]
    assert net(1, 2, 3, d=tensor6).asnumpy() == [12]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cell_mixed_arguments_with_grad():
    """
    Feature: Support kwargs for top graph.
    Description: Take the gradient of a cell whose construct declares *args and **kwargs.
    Expectation: No exception.
    """
    class ForwardNet(nn.Cell):
        def construct(self, *args, **kwargs):
            return args[0] + args[1] - kwargs["d"]

    class GradWrapper(nn.Cell):
        def __init__(self, net):
            super(GradWrapper, self).__init__()
            self.net = net
            self.grad_op = ops.GradOperation()

        def construct(self, *args, **kwargs):
            grad_fn = self.grad_op(self.net)
            return grad_fn(*args, **kwargs)

    context.set_context(mode=context.GRAPH_MODE)
    grad_net = GradWrapper(ForwardNet())
    assert grad_net(tensor1, tensor2, tensor3, d=tensor4, e=tensor5).asnumpy() == [1]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jit_mixed_arguments():
    """
    Feature: Support kwargs for top graph.
    Description: jit-decorated function combining a positional arg with *args and **kwargs.
    Expectation: No exception.
    """
    @jit
    def func(a, *args, **kwargs):
        return a + args[0] + args[1] + kwargs["d"]

    context.set_context(mode=context.GRAPH_MODE)
    assert func(tensor1, tensor2, tensor3, b=tensor4, c=tensor5, d=tensor6).asnumpy() == [12]
    assert func(1, 2, 3, d=tensor6).asnumpy() == [12]

View File

@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,8 +22,10 @@ from mindspore.nn.optim import Momentum
from mindspore.common.api import jit
from mindspore.common import Parameter, ParameterTuple
import mindspore.context as context
context.set_context(mode=context.PYNATIVE_MODE)
@jit
def ConvBnReLU(x):
conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
@ -36,6 +38,7 @@ def ConvBnReLU(x):
return x
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -262,6 +265,7 @@ def test_pynative_ms_function():
assert np.allclose(out_a[0][0].asnumpy(), out_b[0][0].asnumpy(), 0.0001, 0.0001)
assert np.allclose(out_a[1][0].asnumpy(), out_b[1][0].asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@ -290,6 +294,7 @@ def test_pynative_ms_function_mix_execute():
output = net(a, b)
assert output == 8
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@ -324,6 +329,7 @@ def test_pynative_ms_function_empty_graph():
output = net()
assert output.asnumpy() == 10
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -363,6 +369,7 @@ def test_pynative_ms_function_control_flow_if_break():
output = net(x, y, z)
assert (output.asnumpy() == z.asnumpy() * 4).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@ -410,7 +417,6 @@ def test_pynative_ms_function_with_tuple_inputs():
new_grads.append(grad + 1)
return new_grads
x = Tensor(np.ones([2, 2]), dtype=ms.int32)
y = Tensor(np.ones([2, 2]), dtype=ms.int32)
net = Net()
@ -480,7 +486,6 @@ def test_pynative_ms_function_with_kwargs_inputs():
def foo(x, **kwargs):
return x + kwargs.get('y')
with pytest.raises(ValueError):
x = Tensor(3, dtype=ms.int32)
data = {"y": 1}
foo(x, **data)
x = Tensor(3, dtype=ms.int32)
data = {"y": 1}
assert foo(x, **data).asnumpy() == [4]

View File

@ -83,20 +83,20 @@ def test_ms_function_tensor_compile_phase1():
ms_create_time = int(time.time() * 1e9)
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
# The ms_function makes the tensor inputs mutable by default
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
phase1 = _ms_function_executor.compile("fn", x, y)
phase2 = _ms_function_executor.compile("fn", p, q)
assert phase1 != phase2
# mutable api
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
phase1 = _ms_function_executor.compile("fn", mutable(x), mutable(y))
phase2 = _ms_function_executor.compile("fn", mutable(p), mutable(q))
assert phase1 == phase2
# set_mutable api of Tensor
x.set_const_arg(False)
y.set_const_arg(False)
p.set_const_arg(False)
q.set_const_arg(False)
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
phase1 = _ms_function_executor.compile("fn", x, y)
phase2 = _ms_function_executor.compile("fn", p, q)
assert phase1 == phase2
@ -157,20 +157,20 @@ def test_ms_function_tensor_compile_phase2():
q = Tensor([[0.01, 3.0, 1.1], [1.0, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
ms_create_time = int(time.time() * 1e9)
_ms_function_executor = _MindsporeFunctionExecutor(fn, ms_create_time)
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
phase1 = _ms_function_executor.compile("fn", x, y)
phase2 = _ms_function_executor.compile("fn", p, q)
assert phase1 == phase2
# Set const arg.
x.set_const_arg()
y.set_const_arg()
p.set_const_arg()
q.set_const_arg()
phase1 = _ms_function_executor.compile((x, y), "fn")
phase2 = _ms_function_executor.compile((p, q), "fn")
phase1 = _ms_function_executor.compile("fn", x, y)
phase2 = _ms_function_executor.compile("fn", p, q)
assert phase1 != phase2
# mutable api
phase1 = _ms_function_executor.compile((mutable(x), mutable(y)), "fn")
phase2 = _ms_function_executor.compile((mutable(p), mutable(q)), "fn")
phase1 = _ms_function_executor.compile("fn", mutable(x), mutable(y))
phase2 = _ms_function_executor.compile("fn", mutable(p), mutable(q))
assert phase1 == phase2

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.context as context
import mindspore.ops.composite as C
from mindspore import Tensor, Parameter
from mindspore import Tensor, Parameter, jit
from mindspore import nn
from mindspore.nn import Cell
from mindspore.ops import operations as P
@ -323,6 +324,7 @@ def test_args_kwarg_not_used():
Description: Function with unused parameters which are varargs and kwargs.
Expectation: compile success and result == 0
"""
class Net(Cell):
def trivial(self, *args, **kwargs):
return 0
@ -344,6 +346,7 @@ def test_args_kwonlyargs_1_kwarg_not_used():
Description: Function with unused parameters which are varargs, 1 kwonlyargs and kwargs.
Expectation: compile success and result == 0
"""
class Net(Cell):
def trivial(self, *args, only1=3, **kwargs):
return 0
@ -365,6 +368,7 @@ def test_args_kwonlyargs_2_kwarg_not_used():
Description: Function with unused parameters which are varargs, 2 kwonlyargs and kwargs.
Expectation: compile success and result == 0
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return 0
@ -386,6 +390,7 @@ def test_args_1_used_kwonlyargs_kwarg_not_used():
Description: Function with unused parameters which are 1 kwonlyargs and kwargs.
Expectation: compile success and result == x
"""
class Net(Cell):
def trivial(self, *args, only1=3, **kwargs):
return args[0]
@ -407,6 +412,7 @@ def test_args_2_used_kwonlyargs_kwarg_not_used():
Description: Function with unused parameters which are 1 kwonlyargs and kwargs.
Expectation: compile success and result == y
"""
class Net(Cell):
def trivial(self, *args, only1=3, **kwargs):
return args[1]
@ -428,6 +434,7 @@ def test_kwonlyargs_1_used_args_kwarg_not_used():
Description: Function with unused parameters which are varargs and kwargs.
Expectation: compile success and result == only1
"""
class Net(Cell):
def trivial(self, *args, only1=3, **kwargs):
return only1
@ -449,6 +456,7 @@ def test_kwonlyargs_2_used_args_kwarg_not_used():
Description: Function with unused parameters which are varargs and kwargs.
Expectation: compile success and result == only2
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return only2
@ -470,6 +478,7 @@ def test_kwarg_used_args_kwonlyargs_not_used():
Description: Function with unused parameters which are varargs and kwonlyargs.
Expectation: compile success and result == kw1
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return kwargs["kw1"]
@ -490,6 +499,7 @@ def test_args_1_kwonlyargs_1_used_kwarg_not_used():
Description: Function with unused parameters which are kwargs.
Expectation: compile success and result == (x, 3)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], only1)
@ -510,6 +520,7 @@ def test_args_2_kwonlyargs_1_used_kwarg_not_used():
Description: Function with unused parameters which are kwargs.
Expectation: compile success and result == (x, y, 3)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], args[1], only1)
@ -530,6 +541,7 @@ def test_args_2_kwonlyargs_2_used_kwarg_not_used():
Description: Function with unused parameters which are kwargs.
Expectation: compile success and result == (x, y, only1, only2)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], args[1], only1, only2)
@ -550,6 +562,7 @@ def test_kwonlyargs_1_kwarg_used_args_not_used():
Description: Function with unused parameters which are varargs.
Expectation: compile success and result == (y, kw1)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (only1, kwargs["kw1"])
@ -570,6 +583,7 @@ def test_kwonlyargs_2_kwarg_used_args_not_used():
Description: Function with unused parameters which are varargs.
Expectation: compile success and result == (only1, only2, kw1)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (only1, only2, kwargs["kw1"])
@ -590,6 +604,7 @@ def test_args_1_kwarg_used_kwonlyargs_not_used():
Description: Function with unused parameters which are kwonlyargs.
Expectation: compile success and result == (x, kw1)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], kwargs["kw1"])
@ -610,6 +625,7 @@ def test_args_2_kwarg_used_kwonlyargs_not_used():
Description: Function with unused parameters which are kwonlyargs.
Expectation: compile success and result == (x, y, kw1)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], args[1], kwargs["kw1"])
@ -630,6 +646,7 @@ def test_args_1_kwonlyargs_1_kwarg_used():
Description: Function with unused parameters which is kwonlyarg.
Expectation: compile success and result == (x, only1, kw1)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], only1, kwargs["kw1"])
@ -650,6 +667,7 @@ def test_args_2_kwonlyargs_2_kwarg_used():
Description: Function without unused parameters.
Expectation: compile success and result == (x, y, only1, only2, kw1)
"""
class Net(Cell):
def trivial(self, *args, only1=3, only2=4, **kwargs):
return (args[0], args[1], only1, only2, kwargs["kw1"])
@ -662,3 +680,305 @@ def test_args_2_kwonlyargs_2_kwarg_used():
x = 1
y = 2
assert net(x, y) == (x, y, 3, 4, 5)
def test_cell_keyword_argument():
    """
    Feature: Support kwargs for top graph.
    Description: Call a cell with its positional parameters passed as keywords.
    Expectation: No exception.
    """
    class MulNet(nn.Cell):
        def construct(self, a, b):
            return a * b

    net = MulNet()
    assert net(2, b=3) == 6
    assert net(a=2, b=3) == 6
    assert net(b=3, a=2) == 6
def test_cell_default_argument():
    """
    Feature: Support kwargs for top graph.
    Description: Call a cell whose construct has parameters with default values.
    Expectation: No exception.
    """
    class PowNet(nn.Cell):
        def construct(self, x, y=3, z=4):
            return x ** y + z

    net = PowNet()
    assert net(2) == 12
    assert net(2, 1) == 6
    assert net(2, 3, 2) == 10
    assert net(y=1, z=3, x=2) == 5
def test_cell_args1():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct declaring only a varargs parameter.
    Expectation: No exception.
    """
    class SumNet(nn.Cell):
        def construct(self, *args):
            total = 0
            for value in args:
                total = total + value
            return total

    net = SumNet()
    assert net(1, 2, 3) == 6
def test_cell_args2():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct mixing one positional parameter with varargs.
    Expectation: No exception.
    """
    class AccNet(nn.Cell):
        def construct(self, x, *args):
            for extra in args:
                x = x + extra
            return x

    net = AccNet()
    assert net(1, 2, 3) == 6
def test_cell_kwargs1():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct declaring only a kwargs parameter.
    Expectation: No exception.
    """
    class KwNet(nn.Cell):
        def construct(self, **kwargs):
            return kwargs["a"] + kwargs["b"]

    net = KwNet()
    assert net(a=1, b=2, c=3) == 3
def test_cell_kwargs2():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct mixing a positional parameter with kwargs.
    Expectation: No exception.
    """
    class DiffNet(nn.Cell):
        def construct(self, x, **kwargs):
            return x + kwargs["a"] - kwargs["b"]

    net = DiffNet()
    assert net(1, a=2, b=3) == 0
def test_cell_args_kwargs():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct declaring both varargs and kwargs.
    Expectation: No exception.
    """
    class CombineNet(nn.Cell):
        def construct(self, *args, **kwargs):
            return args[0] + args[1] - kwargs["c"] + kwargs["d"]

    context.set_context(mode=context.GRAPH_MODE)
    net = CombineNet()
    assert net(1, 2, c=3, d=4) == 4
@pytest.mark.skip(reason='kwonly not support')
def test_cell_kwonly1():
    """
    Feature: Support kwargs for top graph.
    Description: Cell.construct declaring keyword-only parameters.
    Expectation: No exception.
    """
    class KwOnlyNet(nn.Cell):
        def construct(self, *, a, b):
            return a + b

    context.set_context(mode=context.GRAPH_MODE)
    net = KwOnlyNet()
    assert net(a=1, b=2) == 3
@pytest.mark.skip(reason='kwonly not support')
def test_cell_kwonly2():
    """
    Feature: Support kwargs for top graph.
    Description: Positional parameter plus keyword-only parameters, one with a default.
    Expectation: No exception.
    """
    class KwOnlyDefaultNet(nn.Cell):
        def construct(self, a, *, b, c=3):
            return a + b - c

    context.set_context(mode=context.GRAPH_MODE)
    net = KwOnlyDefaultNet()
    assert net(1, b=2, c=3) == 0
    assert net(1, b=2) == 0
def test_cell_mixed_arguments1():
    """
    Feature: Support kwargs for top graph.
    Description: Positional parameters (one defaulted) together with *args and **kwargs.
    Expectation: No exception.
    """
    class AllKindsNet(nn.Cell):
        def construct(self, a, b, c=3, *args, **kwargs):
            return a + b - c + args[0] - args[1] + kwargs["d"]

    context.set_context(mode=context.GRAPH_MODE)
    net = AllKindsNet()
    assert net(1, 2, 3, 4, 5, d=6, e=7) == 5
@pytest.mark.skip(reason='kwonly not support')
def test_cell_mixed_arguments2():
    """
    Feature: Support kwargs for top graph.
    Description: Positional parameter, *args, keyword-only parameters and **kwargs.
    Expectation: No exception.
    """
    class MixNet(nn.Cell):
        def construct(self, a, *args, b, c=1, **kwargs):
            return a + args[0] - args[1] + b - c + kwargs["d"]

    context.set_context(mode=context.GRAPH_MODE)
    net = MixNet()
    assert net(1, 2, 3, b=4, c=5, d=6) == 5
    assert net(1, 2, 3, b=4, d=6) == 9
@pytest.mark.skip(reason='kwonly not support')
def test_cell_mixed_arguments3():
    """
    Feature: Support kwargs for top graph.
    Description: Positional parameter, bare-star keyword-only parameters and **kwargs.
    Expectation: No exception.
    """
    class StarNet(nn.Cell):
        def construct(self, a, *, b, c=1, **kwargs):
            return a + b - c + kwargs["d"]

    context.set_context(mode=context.GRAPH_MODE)
    net = StarNet()
    assert net(1, b=2, c=3, d=4) == 4
    assert net(1, b=4, d=6) == 10
@pytest.mark.skip(reason='kwonly not support')
def test_cell_mixed_arguments_with_dict_input():
    """
    Feature: Support kwargs for top graph.
    Description: Mixed argument kinds whose values are dictionaries.
    Expectation: No exception.
    """
    class DictNet(nn.Cell):
        def construct(self, a, *args, b, c=1, **kwargs):
            return a["item0"] + args[0] + args[1]["item1"] + b["item2"] + c + kwargs["d"]["item3"]

    context.set_context(mode=context.GRAPH_MODE)
    net = DictNet()
    assert net({"item0": 1}, 2, {"item1": 3}, b={"item2": 4}, c=5, d={"item3": 6}) == 21
def test_cell_mixed_arguments_with_sub_cell():
    """
    Feature: Support kwargs for top graph.
    Description: Forward mixed argument kinds from an outer cell into a sub cell.
    Expectation: No exception.
    """
    class InnerNet(nn.Cell):
        def construct(self, a, *args, b, c=1, **kwargs):
            return a + args[0] + args[1] + b + c + kwargs["d"]

    class OuterNet(nn.Cell):
        def __init__(self):
            super(OuterNet, self).__init__()
            self.subnet = InnerNet()

        def construct(self, a, arg0, arg1, b, c, d):
            return self.subnet(a, arg0, arg1, b=b, c=c, d=d)

    context.set_context(mode=context.GRAPH_MODE)
    net = OuterNet()
    assert net(1, 2, 3, 4, 5, 6) == 21
def test_jit_kwargs():
    """
    Feature: Support kwargs for top graph.
    Description: jit-decorated function declaring varargs and kwargs, run in both
        graph and pynative modes.
    Expectation: No exception.
    """
    @jit
    def func(*args, **kwargs):
        return args[0] + args[1] - kwargs["d"]

    # The same call set must behave identically under both execution modes.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode)
        assert func(1, 2, c=3, d=4) == -1
        assert func(1, 2, d=4) == -1
        assert func(1, 2, d=4, e=5) == -1
        assert func(1, 2, 2.1, 2.2, d=4, e=5, f=6) == -1
@pytest.mark.skip(reason='kwonly not support')
def test_jit_mixed_arguments():
    """
    Feature: Support kwargs for top graph.
    Description: jit-decorated function with positional, varargs, keyword-only and
        kwargs parameters, run in both graph and pynative modes.
    Expectation: No exception.
    """
    @jit
    def func(a, *args, b, c=5, **kwargs):
        return a + args[0] - args[1] + b - c + kwargs["d"]

    # The same call set must behave identically under both execution modes.
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode)
        assert func(1, 2, 3, b=4, c=5, d=6) == 5
        assert func(1, 2, 3, b=4, d=6) == 5
        assert func(1, 2, 3, 4, b=5, c=6, d=7) == 6