forked from mindspore-Ecosystem/mindspore
Support pass args or|and kwargs for OP CreateInstance.
This commit is contained in:
parent
42fef2b09c
commit
cd7f7d40fb
|
@ -278,16 +278,42 @@ def _is_dataclass_instance(obj):
|
|||
return is_dataclass(obj) and not isinstance(obj, type)
|
||||
|
||||
|
||||
def create_obj_instance(cls_type, args_tuple=None):
|
||||
def _convert_tuple_to_args_kwargs(params):
|
||||
args = tuple()
|
||||
kwargs = dict()
|
||||
for param in params:
|
||||
if isinstance(param, dict):
|
||||
kwargs.update(param)
|
||||
else:
|
||||
args += (param,)
|
||||
return (args, kwargs)
|
||||
|
||||
|
||||
def create_obj_instance(cls_type, params=None):
|
||||
"""Create python instance."""
|
||||
if not isinstance(cls_type, type):
|
||||
logger.warning(f"create_obj_instance(), cls_type is not a type, cls_type: {cls_type}")
|
||||
return None
|
||||
|
||||
# Check the type, now only support nn.Cell and Primitive.
|
||||
obj = None
|
||||
if isinstance(cls_type, type):
|
||||
# check the type, now only support nn.Cell and Primitive
|
||||
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
|
||||
if args_tuple is not None:
|
||||
obj = cls_type(*args_tuple)
|
||||
else:
|
||||
obj = cls_type()
|
||||
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
|
||||
# Check arguments, only support *args or **kwargs.
|
||||
if params is None:
|
||||
obj = cls_type()
|
||||
elif isinstance(params, tuple):
|
||||
args, kwargs = _convert_tuple_to_args_kwargs(params)
|
||||
logger.debug(f"create_obj_instance(), args: {args}, kwargs: {kwargs}")
|
||||
if args and kwargs:
|
||||
obj = cls_type(*args, **kwargs)
|
||||
elif args:
|
||||
obj = cls_type(*args)
|
||||
elif kwargs:
|
||||
obj = cls_type(**kwargs)
|
||||
# If invalid parameters.
|
||||
if obj is None:
|
||||
raise ValueError(f"When call 'create_instance', the parameter should be *args or **kwargs, "
|
||||
f"but got {params.__class__.__name__}, params: {params}")
|
||||
return obj
|
||||
|
||||
|
||||
|
|
|
@ -481,7 +481,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
|||
} // namespace
|
||||
|
||||
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
|
||||
// check parameter valid
|
||||
// Check parameter valid
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data is null pointer";
|
||||
return false;
|
||||
|
@ -503,7 +503,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
|
|||
return converted != nullptr;
|
||||
}
|
||||
|
||||
// convert data to graph
|
||||
// Convert data to graph
|
||||
FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
|
||||
std::vector<std::string> results = data_converter::GetObjKey(obj);
|
||||
std::string obj_id = results[0] + python_mod_get_parse_method;
|
||||
|
@ -565,7 +565,7 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
|
|||
return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
|
||||
}
|
||||
|
||||
// get obj detail type
|
||||
// Get obj detail type
|
||||
ResolveTypeDef GetObjType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
auto obj_type =
|
||||
|
@ -573,7 +573,7 @@ ResolveTypeDef GetObjType(const py::object &obj) {
|
|||
return obj_type;
|
||||
}
|
||||
|
||||
// get class instance detail type
|
||||
// Get class instance detail type.
|
||||
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
auto class_type =
|
||||
|
@ -581,26 +581,27 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
|
|||
return class_type;
|
||||
}
|
||||
|
||||
// check the object is Cell Instance
|
||||
// Check the object is Cell Instance.
|
||||
bool IsCellInstance(const py::object &obj) {
|
||||
auto class_type = GetClassInstanceType(obj);
|
||||
bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
|
||||
return isCell;
|
||||
}
|
||||
|
||||
// create the python class instance
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) {
|
||||
// Create the python class instance.
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
|
||||
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
|
||||
// `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
|
||||
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
|
||||
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs);
|
||||
}
|
||||
|
||||
// Generate an appropriate name and set to graph debuginfo
|
||||
// character <> can not used in the dot file, so change to another symbol
|
||||
// Generate an appropriate name and set to graph debuginfo,
|
||||
// character <> can not used in the dot file, so change to another symbol.
|
||||
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph->debug_info());
|
||||
// set detail name info of function
|
||||
// Set detail name info of function
|
||||
std::ostringstream oss;
|
||||
for (size_t i = 0; i < name.size(); i++) {
|
||||
if (name[i] == '<') {
|
||||
|
@ -629,7 +630,7 @@ void ClearObjectCache() {
|
|||
|
||||
static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
|
||||
|
||||
// parse dataclass to mindspore Class type
|
||||
// Parse dataclass to mindspore Class type
|
||||
ClassPtr ParseDataClass(const py::object &cls_obj) {
|
||||
std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
|
||||
std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
|
||||
|
|
|
@ -44,7 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
|
|||
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
||||
|
||||
bool IsCellInstance(const py::object &obj);
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms);
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
||||
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
|
||||
ValuePtr PyDataToValue(const py::object &obj);
|
||||
void ClearObjectCache();
|
||||
|
|
|
@ -1107,7 +1107,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
||||
}
|
||||
|
||||
// get the type parameter
|
||||
// Get the type parameter.
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
|
@ -1131,17 +1131,17 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
auto class_type = type_obj->obj();
|
||||
MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
|
||||
|
||||
// get the create instance obj's parameters
|
||||
pybind11::tuple params = GetParameters(args_spec_list);
|
||||
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
|
||||
py::tuple params = GetParameters(args_spec_list);
|
||||
|
||||
// create class instance
|
||||
// Create class instance.
|
||||
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
|
||||
<< " failed, only support create Cell or Primitive object.";
|
||||
}
|
||||
|
||||
// process the object
|
||||
// Process the object.
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
if (!converted) {
|
||||
|
|
|
@ -167,6 +167,14 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
|
|||
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) {
|
||||
// FuncGraph is not used in the backend, return None
|
||||
ret = py::none();
|
||||
} else if (value->isa<KeywordArg>()) {
|
||||
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
|
||||
auto key = abs_keyword_arg->get_key();
|
||||
auto val = abs_keyword_arg->get_arg()->BuildValue();
|
||||
auto py_value = ValuePtrToPyData(val);
|
||||
auto kwargs = py::kwargs();
|
||||
kwargs[key.c_str()] = py_value;
|
||||
ret = kwargs;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import Tensor, Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from ...ut_filter import non_graph_engine
|
||||
|
||||
|
@ -48,7 +48,7 @@ class Net(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
# Test: creat CELL OR Primitive instance on construct
|
||||
# Test: Create Cell OR Primitive instance on construct
|
||||
@non_graph_engine
|
||||
def test_create_cell_object_on_construct():
|
||||
""" test_create_cell_object_on_construct """
|
||||
|
@ -65,7 +65,7 @@ def test_create_cell_object_on_construct():
|
|||
log.debug("finished test_create_object_on_construct")
|
||||
|
||||
|
||||
# Test: creat CELL OR Primitive instance on construct
|
||||
# Test: Create Cell OR Primitive instance on construct
|
||||
class Net1(nn.Cell):
|
||||
""" Net1 definition """
|
||||
|
||||
|
@ -92,7 +92,7 @@ def test_create_primitive_object_on_construct():
|
|||
log.debug("finished test_create_object_on_construct")
|
||||
|
||||
|
||||
# Test: creat CELL OR Primitive instance on construct use many parameter
|
||||
# Test: Create Cell OR Primitive instance on construct use many parameter
|
||||
class NetM(nn.Cell):
|
||||
""" NetM definition """
|
||||
|
||||
|
@ -120,7 +120,7 @@ class NetC(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
# Test: creat CELL OR Primitive instance on construct
|
||||
# Test: Create Cell OR Primitive instance on construct
|
||||
@non_graph_engine
|
||||
def test_create_cell_object_on_construct_use_many_parameter():
|
||||
""" test_create_cell_object_on_construct_use_many_parameter """
|
||||
|
@ -135,3 +135,60 @@ def test_create_cell_object_on_construct_use_many_parameter():
|
|||
print(np1)
|
||||
print(out_me1)
|
||||
log.debug("finished test_create_object_on_construct")
|
||||
|
||||
|
||||
class NetD(nn.Cell):
|
||||
""" NetD definition """
|
||||
|
||||
def __init__(self):
|
||||
super(NetD, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
concat = P.Concat(axis=1)
|
||||
return concat((x, y))
|
||||
|
||||
|
||||
# Test: Create Cell OR Primitive instance on construct
|
||||
@non_graph_engine
|
||||
def test_create_primitive_object_on_construct_use_kwargs():
|
||||
""" test_create_primitive_object_on_construct_use_kwargs """
|
||||
log.debug("begin test_create_primitive_object_on_construct_use_kwargs")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
|
||||
y = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
|
||||
net = NetD()
|
||||
net(x, y)
|
||||
log.debug("finished test_create_primitive_object_on_construct_use_kwargs")
|
||||
|
||||
|
||||
class NetE(nn.Cell):
|
||||
""" NetE definition """
|
||||
|
||||
def __init__(self):
|
||||
super(NetE, self).__init__()
|
||||
self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')
|
||||
|
||||
def construct(self, x):
|
||||
out_channel = 16
|
||||
kernel_size = 3
|
||||
conv2d = P.Conv2D(out_channel,
|
||||
kernel_size,
|
||||
1,
|
||||
pad_mode='valid',
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1)
|
||||
return conv2d(x, self.w)
|
||||
|
||||
|
||||
# Test: Create Cell OR Primitive instance on construct
|
||||
@non_graph_engine
|
||||
def test_create_primitive_object_on_construct_use_args_and_kwargs():
|
||||
""" test_create_primitive_object_on_construct_use_args_and_kwargs """
|
||||
log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs")
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
inputs = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
|
||||
net = NetE()
|
||||
net(inputs)
|
||||
log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs")
|
||||
|
|
Loading…
Reference in New Issue