Support pass args or|and kwargs for OP CreateInstance.

This commit is contained in:
Zhang Qinghua 2021-04-27 17:38:55 +08:00
parent 42fef2b09c
commit cd7f7d40fb
6 changed files with 124 additions and 32 deletions

View File

@ -278,16 +278,42 @@ def _is_dataclass_instance(obj):
return is_dataclass(obj) and not isinstance(obj, type)
def create_obj_instance(cls_type, args_tuple=None):
def _convert_tuple_to_args_kwargs(params):
args = tuple()
kwargs = dict()
for param in params:
if isinstance(param, dict):
kwargs.update(param)
else:
args += (param,)
return (args, kwargs)
def create_obj_instance(cls_type, params=None):
"""Create python instance."""
if not isinstance(cls_type, type):
logger.warning(f"create_obj_instance(), cls_type is not a type, cls_type: {cls_type}")
return None
# Check the type, now only support nn.Cell and Primitive.
obj = None
if isinstance(cls_type, type):
# check the type, now only support nn.Cell and Primitive
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
if args_tuple is not None:
obj = cls_type(*args_tuple)
else:
obj = cls_type()
if issubclass(cls_type, (nn.Cell, ops.Primitive)):
# Check arguments, only support *args or **kwargs.
if params is None:
obj = cls_type()
elif isinstance(params, tuple):
args, kwargs = _convert_tuple_to_args_kwargs(params)
logger.debug(f"create_obj_instance(), args: {args}, kwargs: {kwargs}")
if args and kwargs:
obj = cls_type(*args, **kwargs)
elif args:
obj = cls_type(*args)
elif kwargs:
obj = cls_type(**kwargs)
# If invalid parameters.
if obj is None:
raise ValueError(f"When call 'create_instance', the parameter should be *args or **kwargs, "
f"but got {params.__class__.__name__}, params: {params}")
return obj

View File

@ -481,7 +481,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
} // namespace
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
// check parameter valid
// Check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
return false;
@ -503,7 +503,7 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
return converted != nullptr;
}
// convert data to graph
// Convert data to graph
FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_id = results[0] + python_mod_get_parse_method;
@ -565,7 +565,7 @@ std::vector<std::string> GetObjKey(const py::object &obj) {
return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
}
// get obj detail type
// Get obj detail type
ResolveTypeDef GetObjType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
auto obj_type =
@ -573,7 +573,7 @@ ResolveTypeDef GetObjType(const py::object &obj) {
return obj_type;
}
// get class instance detail type
// Get class instance detail type.
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
auto class_type =
@ -581,26 +581,27 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
return class_type;
}
// check the object is Cell Instance
// Check the object is Cell Instance.
bool IsCellInstance(const py::object &obj) {
auto class_type = GetClassInstanceType(obj);
bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
return isCell;
}
// create the python class instance
py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
// Create the python class instance.
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
// `args_kwargs` maybe a tuple(*args), tuple(**kwargs), or tuple(*args, **kwargs).
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, args_kwargs);
}
// Generate an appropriate name and set to graph debuginfo
// character <> can not used in the dot file, so change to another symbol
// Generate an appropriate name and set to graph debuginfo,
// character <> can not used in the dot file, so change to another symbol.
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(func_graph->debug_info());
// set detail name info of function
// Set detail name info of function
std::ostringstream oss;
for (size_t i = 0; i < name.size(); i++) {
if (name[i] == '<') {
@ -629,7 +630,7 @@ void ClearObjectCache() {
static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
// parse dataclass to mindspore Class type
// Parse dataclass to mindspore Class type
ClassPtr ParseDataClass(const py::object &cls_obj) {
std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));

View File

@ -44,7 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
bool IsCellInstance(const py::object &obj);
py::object CreatePythonObject(const py::object &type, const py::tuple &params);
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
ValuePtr PyDataToValue(const py::object &obj);
void ClearObjectCache();

View File

@ -1107,7 +1107,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
}
// get the type parameter
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
TypePtr type = args_spec_list[0]->GetTypeTrack();
if (type->type_id() != kMetaTypeTypeType) {
@ -1131,17 +1131,17 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
auto class_type = type_obj->obj();
MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
// get the create instance obj's parameters
pybind11::tuple params = GetParameters(args_spec_list);
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
py::tuple params = GetParameters(args_spec_list);
// create class instance
// Create class instance.
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object" << py::str(class_type)
<< " failed, only support create Cell or Primitive object.";
}
// process the object
// Process the object.
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(obj, &converted_ret, true);
if (!converted) {

View File

@ -167,6 +167,14 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) {
// FuncGraph is not used in the backend, return None
ret = py::none();
} else if (value->isa<KeywordArg>()) {
auto abs_keyword_arg = value->ToAbstract()->cast<abstract::AbstractKeywordArgPtr>();
auto key = abs_keyword_arg->get_key();
auto val = abs_keyword_arg->get_arg()->BuildValue();
auto py_value = ValuePtrToPyData(val);
auto kwargs = py::kwargs();
kwargs[key.c_str()] = py_value;
ret = kwargs;
} else {
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData.";
}

View File

@ -27,7 +27,7 @@ import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.common import Tensor, Parameter
from mindspore.ops import operations as P
from ...ut_filter import non_graph_engine
@ -48,7 +48,7 @@ class Net(nn.Cell):
return x
# Test: creat CELL OR Primitive instance on construct
# Test: Create Cell OR Primitive instance on construct
@non_graph_engine
def test_create_cell_object_on_construct():
""" test_create_cell_object_on_construct """
@ -65,7 +65,7 @@ def test_create_cell_object_on_construct():
log.debug("finished test_create_object_on_construct")
# Test: creat CELL OR Primitive instance on construct
# Test: Create Cell OR Primitive instance on construct
class Net1(nn.Cell):
""" Net1 definition """
@ -92,7 +92,7 @@ def test_create_primitive_object_on_construct():
log.debug("finished test_create_object_on_construct")
# Test: creat CELL OR Primitive instance on construct use many parameter
# Test: Create Cell OR Primitive instance on construct use many parameter
class NetM(nn.Cell):
""" NetM definition """
@ -120,7 +120,7 @@ class NetC(nn.Cell):
return x
# Test: creat CELL OR Primitive instance on construct
# Test: Create Cell OR Primitive instance on construct
@non_graph_engine
def test_create_cell_object_on_construct_use_many_parameter():
""" test_create_cell_object_on_construct_use_many_parameter """
@ -135,3 +135,60 @@ def test_create_cell_object_on_construct_use_many_parameter():
print(np1)
print(out_me1)
log.debug("finished test_create_object_on_construct")
class NetD(nn.Cell):
""" NetD definition """
def __init__(self):
super(NetD, self).__init__()
def construct(self, x, y):
concat = P.Concat(axis=1)
return concat((x, y))
# Test: Create Cell OR Primitive instance on construct
@non_graph_engine
def test_create_primitive_object_on_construct_use_kwargs():
""" test_create_primitive_object_on_construct_use_kwargs """
log.debug("begin test_create_primitive_object_on_construct_use_kwargs")
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
y = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
net = NetD()
net(x, y)
log.debug("finished test_create_primitive_object_on_construct_use_kwargs")
class NetE(nn.Cell):
""" NetE definition """
def __init__(self):
super(NetE, self).__init__()
self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')
def construct(self, x):
out_channel = 16
kernel_size = 3
conv2d = P.Conv2D(out_channel,
kernel_size,
1,
pad_mode='valid',
pad=0,
stride=1,
dilation=1,
group=1)
return conv2d(x, self.w)
# Test: Create Cell OR Primitive instance on construct
@non_graph_engine
def test_create_primitive_object_on_construct_use_args_and_kwargs():
""" test_create_primitive_object_on_construct_use_args_and_kwargs """
log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs")
context.set_context(mode=context.GRAPH_MODE)
inputs = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
net = NetE()
net(inputs)
log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs")