Optimize the error log description for CreateInstance.

Support ms_function for user defined class methods, not only for cell methods.
This commit is contained in:
Zhang Qinghua 2022-06-25 11:21:07 +08:00
parent f704ba6c32
commit 8d48c81857
14 changed files with 159 additions and 71 deletions

View File

@ -404,7 +404,7 @@ bool ParseAction(const ResourcePtr &resource) {
}
FuncGraphPtr top_graph = nullptr;
if (py::isinstance<Cell>(input)) {
if (py::hasattr(input, parse::PYTHON_PARSE_METHOD)) {
top_graph = parse::MakeTopGraph(input, converted_ret);
} else if (converted_ret->isa<FuncGraph>()) {
top_graph = converted_ret->cast<FuncGraphPtr>();

View File

@ -252,6 +252,15 @@ ValuePtr ConvertModuleNameSpace(const py::object &obj) {
ValuePtr ConvertMsClass(const py::object &obj) {
MS_LOG(DEBUG) << "Converting ms class";
// Convert class instance decorated with ms_class.
if (py::hasattr(obj, PYTHON_PARSE_METHOD)) {
MS_LOG(DEBUG) << "Convert obj to func graph.";
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return nullptr;
}
return func_graph;
}
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj);
auto cls_name = py::cast<std::string>(name);
@ -382,7 +391,8 @@ ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
}
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD ||
(obj_type == RESOLVE_TYPE_CLASS_INSTANCE && py::hasattr(obj, PYTHON_PARSE_METHOD))) {
MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
FuncGraphPtr func_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_PARSE_METHOD, forbid_reuse);
if (func_graph == nullptr) {
@ -392,13 +402,7 @@ ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
return func_graph;
}
if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
// Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct'
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
MS_LOG(DEBUG) << "name_space: " << res->ToString();
return res;
MS_LOG(EXCEPTION) << "Fail to convert class instance: " << py::str(obj);
}
// Start RESOLVE_TYPE_INVALID...
// The fallback feature is enabled in default.

View File

@ -2890,7 +2890,7 @@ bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_
// Call python parse, get the parser fn
module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_PARSE_METHOD);
// Get the obj type
auto type = data_converter::GetObjType(obj_);
@ -2909,7 +2909,7 @@ bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_
function_ = obj_;
obj_ = method_object;
} else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
// obj is class instance, get the method to parse.
// 'obj' is class instance, get the method to parse.
function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
if (py::isinstance<py::none>(function_)) {
MS_LOG(ERROR) << "Get obj method function failed.";
@ -2923,7 +2923,7 @@ bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_
return false;
}
} else {
MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type: " << type;
return false;
}
@ -3015,11 +3015,11 @@ bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph)
SetMixedPrecisionFlag(obj, func_graph);
if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
if (!py::hasattr(obj, PYTHON_FUNC_GRAPH_FLAGS)) {
MS_LOG(DEBUG) << "No flags";
return true;
}
py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_FUNC_GRAPH_FLAGS);
for (auto &item : flags) {
if (!py::isinstance<py::str>(item.first)) {
MS_LOG(ERROR) << "Type error in flags dict convert";

View File

@ -143,8 +143,8 @@ const char PYTHON_GET_METHOD_LEN[] = "__len__";
const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__";
const char PYTHON_GET_OBJ_DESC[] = "__str__";
const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__";
const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
const char PYTHON_PARSE_METHOD[] = "__parse_method__";
const char PYTHON_FUNC_GRAPH_FLAGS[] = "_func_graph_flags";
// Define the parse constant.
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
@ -160,14 +160,15 @@ const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // For Module
// Define Resolve type.
enum ResolveTypeDef : int64_t {
RESOLVE_TYPE_NONE = 0, // Resolve None
RESOLVE_TYPE_FUNCTION = 1, // Resolve function
RESOLVE_TYPE_METHOD = 2, // Resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class
RESOLVE_TYPE_NUMPY_INT_NUMBER = 5, // Resolve numpy number int type
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 6, // Resolve numpy number float type
RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid
RESOLVE_TYPE_NONE = 0, // Resolve None.
RESOLVE_TYPE_FUNCTION = 1, // Resolve function.
RESOLVE_TYPE_METHOD = 2, // Resolve class method.
RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type.
RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class.
RESOLVE_TYPE_NAMESPACE_INSTANCE = 5, // Resolve the namespace instance.
RESOLVE_TYPE_NUMPY_INT_NUMBER = 6, // Resolve numpy number int type.
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 7, // Resolve numpy number float type.
RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid.
};
// Define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE.

View File

@ -253,12 +253,20 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) {
void CheckArgsValid(const py::object &source_obj, const py::tuple &args) {
std::string obj_desc;
if (py::isinstance<Cell>(source_obj)) {
if (py::hasattr(source_obj, parse::PYTHON_PARSE_METHOD)) {
auto cell_class_name = source_obj.attr("__class__").attr("__name__");
obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
auto ms_function_name = source_obj.attr(parse::PYTHON_PARSE_METHOD);
obj_desc = "'" + py::cast<std::string>(cell_class_name) + "." + py::cast<std::string>(ms_function_name) + "'";
} else {
auto ms_function_name = source_obj.attr("__name__");
obj_desc = "'" + py::cast<std::string>(ms_function_name) + "'";
if (py::hasattr(source_obj, "__name__")) {
auto ms_function_name = source_obj.attr("__name__");
obj_desc = "'" + py::cast<std::string>(ms_function_name) + "'";
} else if (py::isinstance<Cell>(source_obj)) {
auto cell_class_name = source_obj.attr("__class__").attr("__name__");
obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
} else {
MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source_obj);
}
}
for (size_t i = 0; i < args.size(); i++) {
if (!CheckArgValid(args[i])) {

View File

@ -1761,9 +1761,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
auto arg = args_spec_list[i + 1];
MS_EXCEPTION_IF_NULL(arg);
if (IsContainUndetermined(arg)) {
MS_EXCEPTION(TypeError) << "The " << i << "th input of method __init__ for "
MS_EXCEPTION(TypeError) << "The " << i << "th initializing input to create instance for "
<< args_spec_list[0]->BuildValue()->ToString()
<< " should be a scalar but got:" << arg->ToString();
<< " should be a constant, but got: " << arg->ToString();
}
// Because the Tensor's AbstractTensor can't get value from GetValueTrack.
ValuePtr param_value = arg->BuildValue();

View File

@ -37,7 +37,7 @@ from mindspore import ops
from mindspore.common.api import _MindsporeFunctionExecutor, _convert_python_data
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .namespace import Namespace, CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist
@ -46,14 +46,15 @@ RET_SUCCESS = 0
RET_FAILURE = 0xFF
# Define resolve type
RESOLVE_TYPE_NONE = 0 # Resolve None
RESOLVE_TYPE_FUNCTION = 1 # Resolve function
RESOLVE_TYPE_METHOD = 2 # Resolve class method
RESOLVE_TYPE_CLASS_TYPE = 3 # Resolve class type
RESOLVE_TYPE_CLASS_INSTANCE = 4 # Resolve the class instance of common class
RESOLVE_TYPE_NUMPY_INT_NUMBER = 5 # Resolve numpy int number
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 6 # Resolve numpy float number
RESOLVE_TYPE_INVALID = 0xFF
RESOLVE_TYPE_NONE = 0 # Resolve None.
RESOLVE_TYPE_FUNCTION = 1 # Resolve function.
RESOLVE_TYPE_METHOD = 2 # Resolve class method.
RESOLVE_TYPE_CLASS_TYPE = 3 # Resolve class type.
RESOLVE_TYPE_CLASS_INSTANCE = 4 # Resolve the class instance of common class.
RESOLVE_TYPE_NAMESPACE_INSTANCE = 5 # Resolve the namespace instance.
RESOLVE_TYPE_NUMPY_INT_NUMBER = 6 # Resolve numpy int number.
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 7 # Resolve numpy float number.
RESOLVE_TYPE_INVALID = 0xFF # Resolve invalid.
# Define the class instance detail type
# When the type is RESOLVE_TYPE_CLASS_INSTANCE
@ -125,7 +126,7 @@ def parse_cb(func, parse_method=None):
def get_parse_method_of_class(obj, parse_method=None):
"""
Het parse method of class.
Get parse method of class.
Args:
obj(Object): Instance of class.
@ -216,13 +217,13 @@ def resolve_symbol(namespace, symbol):
try:
resolve_ = namespace[symbol]
# list and dict is not hashable ,it can not be key for the map, just return the result
# The list and dict is not hashable, it can not be key for the map, just return the result
if isinstance(resolve_, (tuple, list, dict)):
return resolve_
if getattr(resolve_, "__hash__") is None:
return resolve_
# Raise a proper error if not using Fallback feature.
# Raise a proper error if not using JIT Fallback feature.
if support_fallback_ == '0':
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
if namespace.name == "numpy" and \
@ -317,6 +318,8 @@ def get_obj_type(obj):
obj_type = RESOLVE_TYPE_METHOD
elif isinstance(obj, type):
obj_type = RESOLVE_TYPE_CLASS_TYPE
elif isinstance(obj, Namespace):
obj_type = RESOLVE_TYPE_NAMESPACE_INSTANCE
elif _is_class_instance(obj):
obj_type = RESOLVE_TYPE_CLASS_INSTANCE
elif _is_numpy_int_number(obj):
@ -361,7 +364,7 @@ def _is_ms_class(obj):
def _is_class_instance(obj):
"""Confirm the obj is class instance."""
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_ms_class(obj)
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_ms_class(obj) or hasattr(obj, '__parse_method__')
def _is_numpy_int_number(obj):
@ -504,9 +507,6 @@ def is_class_type(cls):
def get_ms_class_name(cls):
"""Get the name of the class instance decorated by ms_class."""
# Check if cls is nn.Cell.
if isinstance(cls, nn.Cell):
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
if isinstance(cls, type):
return cls.__name__
return cls.__class__.__name__

View File

@ -268,14 +268,17 @@ class _MindsporeFunctionExecutor:
generate_name = generate_name + ".grad"
if is_pynative_parallel():
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
self.fn.__parse_method__ = method_name
# Add key with obj
if self.obj is not None:
if self.obj.__module__ != self.fn.__module__:
logger.info(f'`obj` module not equal to `fn` module: {self.obj.__module__}, {self.fn.__module__}')
self.obj.__parse_method__ = method_name
generate_name = generate_name + '.' + str(self.obj.create_time) + '.' + str(id(self.obj))
if isinstance(self.obj, ms.nn.Cell):
generate_name = generate_name + '.' + str(self.obj.create_time)
else:
generate_name = generate_name + '.' + str(self._create_time)
generate_name = generate_name + '.' + str(id(self.obj))
else:
# Different instance of same class may use same memory(means same obj_id) at diff times.
# To avoid unexpected phase matched, add create_time to generate_name.
@ -298,7 +301,8 @@ class _MindsporeFunctionExecutor:
if self.obj is None:
is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
else:
self._graph_executor.set_weights_values(self.obj.parameters_dict())
if isinstance(self.obj, ms.nn.Cell):
self._graph_executor.set_weights_values(self.obj.parameters_dict())
is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
if is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
@ -570,6 +574,9 @@ def ms_class(cls):
# Check if cls is of type class.
if not inspect.isclass(cls):
raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
# Check if cls is nn.Cell.
if issubclass(cls, ms.nn.Cell):
raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
logger.info(f'Found ms_class: {cls}.')
setattr(cls, '__ms_class__', True)
return cls

View File

@ -83,7 +83,7 @@ class Cell(Cell_):
"""
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_construct_inputs_num', '_create_time', '_func_graph_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
'_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
'_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
@ -1473,9 +1473,9 @@ class Cell(Cell_):
flags (dict): Network configuration information, currently it is used for the binding of network and
dataset. Users can also customize network attributes by this parameter. Default: None.
"""
if not hasattr(self, "_mindspore_flags"):
self._mindspore_flags = {}
self._mindspore_flags.update({**flags})
if not hasattr(self, "_func_graph_flags"):
self._func_graph_flags = {}
self._func_graph_flags.update({**flags})
self.__dict__.update({**flags})
self._add_mixed_precision_flag(**flags)
return self
@ -1502,9 +1502,9 @@ class Cell(Cell_):
"""
Get the self_defined attributes of the cell, which can be added by `add_flags` method.
"""
if not hasattr(self, "_mindspore_flags"):
self._mindspore_flags = {}
return self._mindspore_flags
if not hasattr(self, "_func_graph_flags"):
self._func_graph_flags = {}
return self._func_graph_flags
def _set_mixed_precision_type_recursive(self, mixed_type):
"""Set mixed precision type to each cell"""

View File

@ -50,15 +50,15 @@ def add_flags(fn=None, **flags):
Examples:
>>> net = Net();
>>> net = add_flags(net, predit=True)
>>> print(hasattr(net, '_mindspore_flags'))
>>> print(hasattr(net, '_func_graph_flags'))
True
"""
def deco(fn):
# need set the attr and access on c++
if not hasattr(fn, "_mindspore_flags"):
fn._mindspore_flags = {}
if not hasattr(fn, "_func_graph_flags"):
fn._func_graph_flags = {}
fn._mindspore_flags.update({**flags})
fn._func_graph_flags.update({**flags})
return fn
ret = deco
if fn is not None:
@ -84,13 +84,13 @@ def core(fn=None, **flags):
Examples:
>>> net = Net()
>>> net = core(net, predit=True)
>>> print(hasattr(net, '_mindspore_flags'))
>>> print(hasattr(net, '_func_graph_flags'))
True
"""
# need set the attr and access on c++
def deco(fn):
fn._mindspore_flags = {
fn._func_graph_flags = {
'core': True,
**flags,
}

View File

@ -85,7 +85,7 @@ class _DataWrapper(nn.Cell):
super(_DataWrapper, self).__init__(
auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
flags = getattr(network.__class__.construct, "_func_graph_flags", {})
self.info = (dataset_types, dataset_shapes)
self.add_flags(**flags)
self.get_next = P.GetNext(

View File

@ -440,12 +440,12 @@ def test_fallback_raise_error_decorate_cell():
Description: Decorator ms_class cannot be used for nn.Cell
Expectation: No exception.
"""
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
with pytest.raises(TypeError):
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
x = Tensor(1)
net = Net()
net(x)

View File

@ -0,0 +1,71 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
class Net:
@ms.ms_function
def test(self, x, y):
return ms.ops.mul(x, y)
def test_user_defined_class_with_ms_function():
"""
Feature: User defined class with ms_function.
Description: Test user defined class method with ms_function.
Expectation: No exception.
"""
x = ms.Tensor([3])
y = ms.Tensor([2])
net = Net()
net.test(x, y)
@ms.ms_class
class MsClassNet:
@ms.ms_function
def test(self, x, y):
return ms.ops.mul(x, y)
def test_ms_class_with_ms_function():
"""
Feature: ms_class with ms_function.
Description: Test ms_class method with ms_function.
Expectation: No exception.
"""
x = ms.Tensor([3])
y = ms.Tensor([2])
net = MsClassNet()
net.test(x, y)
class CellNet(ms.nn.Cell):
@ms.ms_function
def test(self, x, y):
return ms.ops.mul(x, y)
def test_cell_with_ms_function():
"""
Feature: Cell with ms_function.
Description: Test Cell method with ms_function.
Expectation: No exception.
"""
x = ms.Tensor([3])
y = ms.Tensor([2])
net = CellNet()
net.test(x, y)

View File

@ -25,9 +25,6 @@ grad_all = C.GradOperation(get_all=True)
class CellBprop(nn.Cell):
def __init__(self):
super(CellBprop, self).__init__()
def construct(self, x, y):
return 2 * x * x + y * y