diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index ffa2b8ec016..85d17b27f5f 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -278,16 +278,42 @@ def _is_dataclass_instance(obj): return is_dataclass(obj) and not isinstance(obj, type) -def create_obj_instance(cls_type, args_tuple=None): +def _convert_tuple_to_args_kwargs(params): + args = tuple() + kwargs = dict() + for param in params: + if isinstance(param, dict): + kwargs.update(param) + else: + args += (param,) + return (args, kwargs) + + +def create_obj_instance(cls_type, params=None): """Create python instance.""" + if not isinstance(cls_type, type): + logger.warning(f"create_obj_instance(), cls_type is not a type, cls_type: {cls_type}") + return None + + # Check the type, now only support nn.Cell and Primitive. obj = None - if isinstance(cls_type, type): - # check the type, now only support nn.Cell and Primitive - if issubclass(cls_type, (nn.Cell, ops.Primitive)): - if args_tuple is not None: - obj = cls_type(*args_tuple) - else: - obj = cls_type() + if issubclass(cls_type, (nn.Cell, ops.Primitive)): + # Check arguments, only support *args or **kwargs. + if params is None: + obj = cls_type() + elif isinstance(params, tuple): + args, kwargs = _convert_tuple_to_args_kwargs(params) + logger.debug(f"create_obj_instance(), args: {args}, kwargs: {kwargs}") + if args and kwargs: + obj = cls_type(*args, **kwargs) + elif args: + obj = cls_type(*args) + elif kwargs: + obj = cls_type(**kwargs) + # If invalid parameters. + if obj is None: + raise ValueError(f"When call 'create_instance', the parameter should be *args or **kwargs, " + f"but got {params.__class__.__name__}, params: {params}") return obj diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 937908a88cb..72cf956b88c 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -481,7 +481,7 @@ std::vector GetDataConverters() { } // namespace bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) { - // check parameter valid + // Check parameter valid if (data == nullptr) { MS_LOG(ERROR) << "Data is null pointer"; return false; @@ -503,7 +503,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature return converted != nullptr; } -// convert data to graph +// Convert data to graph FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { std::vector results = data_converter::GetObjKey(obj); std::string obj_id = results[0] + python_mod_get_parse_method; @@ -565,7 +565,7 @@ std::vector GetObjKey(const py::object &obj) { return {py::cast(obj_tuple[0]), py::cast(obj_tuple[1])}; } -// get obj detail type +// Get obj detail type ResolveTypeDef GetObjType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto obj_type = @@ -573,7 +573,7 @@ ResolveTypeDef GetObjType(const py::object &obj) { return obj_type; } -// get class instance detail type +// Get class instance detail type. ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto class_type = @@ -581,26 +581,27 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { return class_type; } -// check the object is Cell Instance +// Check the object is Cell Instance. bool IsCellInstance(const py::object &obj) { auto class_type = GetClassInstanceType(obj); bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); return isCell; } -// create the python class instance -py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { +// Create the python class instance. +py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type) - : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); + // `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs). + return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type) + : python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs); } -// Generate an appropriate name and set to graph debuginfo -// character <> can not used in the dot file, so change to another symbol +// Generate an appropriate name and set to graph debuginfo, +// character <> can not used in the dot file, so change to another symbol. void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph->debug_info()); - // set detail name info of function + // Set detail name info of function std::ostringstream oss; for (size_t i = 0; i < name.size(); i++) { if (name[i] == '<') { @@ -629,7 +630,7 @@ void ClearObjectCache() { static std::unordered_map g_dataClassToClass = {}; -// parse dataclass to mindspore Class type +// Parse dataclass to mindspore Class type ClassPtr ParseDataClass(const py::object &cls_obj) { std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h index e279069d730..3cfc6a84352 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h @@ -44,7 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj); ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); bool IsCellInstance(const py::object &obj); -py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); +py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs); void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); ValuePtr PyDataToValue(const py::object &obj); void ClearObjectCache(); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 5282e65d913..9e89322121b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1107,7 +1107,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; } - // get the type parameter + // Get the type parameter. MS_EXCEPTION_IF_NULL(args_spec_list[0]); TypePtr type = args_spec_list[0]->GetTypeTrack(); if (type->type_id() != kMetaTypeTypeType) { @@ -1131,17 +1131,17 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { auto class_type = type_obj->obj(); MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; - // get the create instance obj's parameters - pybind11::tuple params = GetParameters(args_spec_list); + // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs). + py::tuple params = GetParameters(args_spec_list); - // create class instance + // Create class instance. auto obj = parse::data_converter::CreatePythonObject(class_type, params); if (py::isinstance(obj)) { MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type) << " failed, only support create Cell or Primitive object."; } - // process the object + // Process the object. ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret, true); if (!converted) { diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index fd1c751eb64..d59ea447e96 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -167,6 +167,14 @@ py::object ValuePtrToPyData(const ValuePtr &value) { } else if (value->isa() || value->isa() || value->isa() || value->isa()) { // FuncGraph is not used in the backend, return None ret = py::none(); + } else if (value->isa()) { + auto abs_keyword_arg = value->ToAbstract()->cast(); + auto key = abs_keyword_arg->get_key(); + auto val = abs_keyword_arg->get_arg()->BuildValue(); + auto py_value = ValuePtrToPyData(val); + auto kwargs = py::kwargs(); + kwargs[key.c_str()] = py_value; + ret = kwargs; } else { MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData."; } diff --git a/tests/ut/python/pipeline/parse/test_create_obj.py b/tests/ut/python/pipeline/parse/test_create_obj.py index 201aa0ae2db..6c440936d56 100644 --- a/tests/ut/python/pipeline/parse/test_create_obj.py +++ b/tests/ut/python/pipeline/parse/test_create_obj.py @@ -27,7 +27,7 @@ import numpy as np import mindspore.nn as nn from mindspore import context from mindspore.common.api import ms_function -from mindspore.common.tensor import Tensor +from mindspore.common import Tensor, Parameter from mindspore.ops import operations as P from ...ut_filter import non_graph_engine @@ -48,7 +48,7 @@ class Net(nn.Cell): return x -# Test: creat CELL OR Primitive instance on construct +# Test: Create Cell OR Primitive instance on construct @non_graph_engine def test_create_cell_object_on_construct(): """ test_create_cell_object_on_construct """ @@ -65,7 +65,7 @@ def test_create_cell_object_on_construct(): log.debug("finished test_create_object_on_construct") -# Test: creat CELL OR Primitive instance on construct +# Test: Create Cell OR Primitive instance on construct class Net1(nn.Cell): """ Net1 definition """ @@ -92,7 +92,7 @@ def test_create_primitive_object_on_construct(): log.debug("finished test_create_object_on_construct") -# Test: creat CELL OR Primitive instance on construct use many parameter +# Test: Create Cell OR Primitive instance on construct use many parameter class NetM(nn.Cell): """ NetM definition """ @@ -120,7 +120,7 @@ class NetC(nn.Cell): return x -# Test: creat CELL OR Primitive instance on construct +# Test: Create Cell OR Primitive instance on construct @non_graph_engine def test_create_cell_object_on_construct_use_many_parameter(): """ test_create_cell_object_on_construct_use_many_parameter """ @@ -135,3 +135,60 @@ def test_create_cell_object_on_construct_use_many_parameter(): print(np1) print(out_me1) log.debug("finished test_create_object_on_construct") + + +class NetD(nn.Cell): + """ NetD definition """ + + def __init__(self): + super(NetD, self).__init__() + + def construct(self, x, y): + concat = P.Concat(axis=1) + return concat((x, y)) + + +# Test: Create Cell OR Primitive instance on construct +@non_graph_engine +def test_create_primitive_object_on_construct_use_kwargs(): + """ test_create_primitive_object_on_construct_use_kwargs """ + log.debug("begin test_create_primitive_object_on_construct_use_kwargs") + context.set_context(mode=context.GRAPH_MODE) + x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32)) + y = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32)) + net = NetD() + net(x, y) + log.debug("finished test_create_primitive_object_on_construct_use_kwargs") + + +class NetE(nn.Cell): + """ NetE definition """ + + def __init__(self): + super(NetE, self).__init__() + self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w') + + def construct(self, x): + out_channel = 16 + kernel_size = 3 + conv2d = P.Conv2D(out_channel, + kernel_size, + 1, + pad_mode='valid', + pad=0, + stride=1, + dilation=1, + group=1) + return conv2d(x, self.w) + + +# Test: Create Cell OR Primitive instance on construct +@non_graph_engine +def test_create_primitive_object_on_construct_use_args_and_kwargs(): + """ test_create_primitive_object_on_construct_use_args_and_kwargs """ + log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs") + context.set_context(mode=context.GRAPH_MODE) + inputs = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32)) + net = NetE() + net(inputs) + log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs")