Optimize the error log description for CreateInstance.
Support ms_function for user defined class methods, not only for cell methods.
This commit is contained in:
parent
f704ba6c32
commit
8d48c81857
|
@ -404,7 +404,7 @@ bool ParseAction(const ResourcePtr &resource) {
|
|||
}
|
||||
|
||||
FuncGraphPtr top_graph = nullptr;
|
||||
if (py::isinstance<Cell>(input)) {
|
||||
if (py::hasattr(input, parse::PYTHON_PARSE_METHOD)) {
|
||||
top_graph = parse::MakeTopGraph(input, converted_ret);
|
||||
} else if (converted_ret->isa<FuncGraph>()) {
|
||||
top_graph = converted_ret->cast<FuncGraphPtr>();
|
||||
|
|
|
@ -252,6 +252,15 @@ ValuePtr ConvertModuleNameSpace(const py::object &obj) {
|
|||
ValuePtr ConvertMsClass(const py::object &obj) {
|
||||
MS_LOG(DEBUG) << "Converting ms class";
|
||||
// Convert class instance decorated with ms_class.
|
||||
if (py::hasattr(obj, PYTHON_PARSE_METHOD)) {
|
||||
MS_LOG(DEBUG) << "Convert obj to func graph.";
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parse resolve function error.";
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph;
|
||||
}
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj);
|
||||
auto cls_name = py::cast<std::string>(name);
|
||||
|
@ -382,7 +391,8 @@ ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
|
|||
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1.
|
||||
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
|
||||
}
|
||||
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
|
||||
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD ||
|
||||
(obj_type == RESOLVE_TYPE_CLASS_INSTANCE && py::hasattr(obj, PYTHON_PARSE_METHOD))) {
|
||||
MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_PARSE_METHOD, forbid_reuse);
|
||||
if (func_graph == nullptr) {
|
||||
|
@ -392,13 +402,7 @@ ValuePtr ConvertOtherObj(const py::object &obj, bool forbid_reuse = false) {
|
|||
return func_graph;
|
||||
}
|
||||
if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
|
||||
// Create the namespace for common class instance
|
||||
// When the obj is Cell, default parse the 'construct'
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
|
||||
auto res = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
|
||||
MS_LOG(DEBUG) << "name_space: " << res->ToString();
|
||||
return res;
|
||||
MS_LOG(EXCEPTION) << "Fail to convert class instance: " << py::str(obj);
|
||||
}
|
||||
// Start RESOLVE_TYPE_INVALID...
|
||||
// The fallback feature is enabled in default.
|
||||
|
|
|
@ -2890,7 +2890,7 @@ bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_
|
|||
|
||||
// Call python parse, get the parser fn
|
||||
module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
|
||||
py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_PARSE_METHOD);
|
||||
|
||||
// Get the obj type
|
||||
auto type = data_converter::GetObjType(obj_);
|
||||
|
@ -2909,7 +2909,7 @@ bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_
|
|||
function_ = obj_;
|
||||
obj_ = method_object;
|
||||
} else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
|
||||
// obj is class instance, get the method to parse.
|
||||
// 'obj' is class instance, get the method to parse.
|
||||
function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
|
||||
if (py::isinstance<py::none>(function_)) {
|
||||
MS_LOG(ERROR) << "Get obj method function failed.";
|
||||
|
@ -2923,7 +2923,7 @@ bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_
|
|||
return false;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
|
||||
MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type: " << type;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -3015,11 +3015,11 @@ bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph)
|
|||
|
||||
SetMixedPrecisionFlag(obj, func_graph);
|
||||
|
||||
if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
|
||||
if (!py::hasattr(obj, PYTHON_FUNC_GRAPH_FLAGS)) {
|
||||
MS_LOG(DEBUG) << "No flags";
|
||||
return true;
|
||||
}
|
||||
py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
|
||||
py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_FUNC_GRAPH_FLAGS);
|
||||
for (auto &item : flags) {
|
||||
if (!py::isinstance<py::str>(item.first)) {
|
||||
MS_LOG(ERROR) << "Type error in flags dict convert";
|
||||
|
|
|
@ -143,8 +143,8 @@ const char PYTHON_GET_METHOD_LEN[] = "__len__";
|
|||
const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__";
|
||||
const char PYTHON_GET_OBJ_DESC[] = "__str__";
|
||||
|
||||
const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__";
|
||||
const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
|
||||
const char PYTHON_PARSE_METHOD[] = "__parse_method__";
|
||||
const char PYTHON_FUNC_GRAPH_FLAGS[] = "_func_graph_flags";
|
||||
|
||||
// Define the parse constant.
|
||||
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
|
||||
|
@ -160,14 +160,15 @@ const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // For Module
|
|||
|
||||
// Define Resolve type.
|
||||
enum ResolveTypeDef : int64_t {
|
||||
RESOLVE_TYPE_NONE = 0, // Resolve None
|
||||
RESOLVE_TYPE_FUNCTION = 1, // Resolve function
|
||||
RESOLVE_TYPE_METHOD = 2, // Resolve class method
|
||||
RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type
|
||||
RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class
|
||||
RESOLVE_TYPE_NUMPY_INT_NUMBER = 5, // Resolve numpy number int type
|
||||
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 6, // Resolve numpy number float type
|
||||
RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid
|
||||
RESOLVE_TYPE_NONE = 0, // Resolve None.
|
||||
RESOLVE_TYPE_FUNCTION = 1, // Resolve function.
|
||||
RESOLVE_TYPE_METHOD = 2, // Resolve class method.
|
||||
RESOLVE_TYPE_CLASS_TYPE = 3, // Resolve class type.
|
||||
RESOLVE_TYPE_CLASS_INSTANCE = 4, // Resolve the class instance of common class.
|
||||
RESOLVE_TYPE_NAMESPACE_INSTANCE = 5, // Resolve the namespace instance.
|
||||
RESOLVE_TYPE_NUMPY_INT_NUMBER = 6, // Resolve numpy number int type.
|
||||
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 7, // Resolve numpy number float type.
|
||||
RESOLVE_TYPE_INVALID = 0xFF // Resolve invalid.
|
||||
};
|
||||
|
||||
// Define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE.
|
||||
|
|
|
@ -253,12 +253,20 @@ void SetValueMutable(const abstract::AbstractBasePtr &abs) {
|
|||
|
||||
void CheckArgsValid(const py::object &source_obj, const py::tuple &args) {
|
||||
std::string obj_desc;
|
||||
if (py::isinstance<Cell>(source_obj)) {
|
||||
if (py::hasattr(source_obj, parse::PYTHON_PARSE_METHOD)) {
|
||||
auto cell_class_name = source_obj.attr("__class__").attr("__name__");
|
||||
obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
|
||||
auto ms_function_name = source_obj.attr(parse::PYTHON_PARSE_METHOD);
|
||||
obj_desc = "'" + py::cast<std::string>(cell_class_name) + "." + py::cast<std::string>(ms_function_name) + "'";
|
||||
} else {
|
||||
auto ms_function_name = source_obj.attr("__name__");
|
||||
obj_desc = "'" + py::cast<std::string>(ms_function_name) + "'";
|
||||
if (py::hasattr(source_obj, "__name__")) {
|
||||
auto ms_function_name = source_obj.attr("__name__");
|
||||
obj_desc = "'" + py::cast<std::string>(ms_function_name) + "'";
|
||||
} else if (py::isinstance<Cell>(source_obj)) {
|
||||
auto cell_class_name = source_obj.attr("__class__").attr("__name__");
|
||||
obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source_obj);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (!CheckArgValid(args[i])) {
|
||||
|
|
|
@ -1761,9 +1761,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
auto arg = args_spec_list[i + 1];
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (IsContainUndetermined(arg)) {
|
||||
MS_EXCEPTION(TypeError) << "The " << i << "th input of method __init__ for "
|
||||
MS_EXCEPTION(TypeError) << "The " << i << "th initializing input to create instance for "
|
||||
<< args_spec_list[0]->BuildValue()->ToString()
|
||||
<< " should be a scalar but got:" << arg->ToString();
|
||||
<< " should be a constant, but got: " << arg->ToString();
|
||||
}
|
||||
// Because the Tensor's AbstractTensor can't get value from GetValueTrack.
|
||||
ValuePtr param_value = arg->BuildValue();
|
||||
|
|
|
@ -37,7 +37,7 @@ from mindspore import ops
|
|||
from mindspore.common.api import _MindsporeFunctionExecutor, _convert_python_data
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .namespace import Namespace, CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist
|
||||
|
||||
|
@ -46,14 +46,15 @@ RET_SUCCESS = 0
|
|||
RET_FAILURE = 0xFF
|
||||
|
||||
# Define resolve type
|
||||
RESOLVE_TYPE_NONE = 0 # Resolve None
|
||||
RESOLVE_TYPE_FUNCTION = 1 # Resolve function
|
||||
RESOLVE_TYPE_METHOD = 2 # Resolve class method
|
||||
RESOLVE_TYPE_CLASS_TYPE = 3 # Resolve class type
|
||||
RESOLVE_TYPE_CLASS_INSTANCE = 4 # Resolve the class instance of common class
|
||||
RESOLVE_TYPE_NUMPY_INT_NUMBER = 5 # Resolve numpy int number
|
||||
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 6 # Resolve numpy float number
|
||||
RESOLVE_TYPE_INVALID = 0xFF
|
||||
RESOLVE_TYPE_NONE = 0 # Resolve None.
|
||||
RESOLVE_TYPE_FUNCTION = 1 # Resolve function.
|
||||
RESOLVE_TYPE_METHOD = 2 # Resolve class method.
|
||||
RESOLVE_TYPE_CLASS_TYPE = 3 # Resolve class type.
|
||||
RESOLVE_TYPE_CLASS_INSTANCE = 4 # Resolve the class instance of common class.
|
||||
RESOLVE_TYPE_NAMESPACE_INSTANCE = 5 # Resolve the namespace instance.
|
||||
RESOLVE_TYPE_NUMPY_INT_NUMBER = 6 # Resolve numpy int number.
|
||||
RESOLVE_TYPE_NUMPY_FLOAT_NUMBER = 7 # Resolve numpy float number.
|
||||
RESOLVE_TYPE_INVALID = 0xFF # Resolve invalid.
|
||||
|
||||
# Define the class instance detail type
|
||||
# When the type is RESOLVE_TYPE_CLASS_INSTANCE
|
||||
|
@ -125,7 +126,7 @@ def parse_cb(func, parse_method=None):
|
|||
|
||||
def get_parse_method_of_class(obj, parse_method=None):
|
||||
"""
|
||||
Het parse method of class.
|
||||
Get parse method of class.
|
||||
|
||||
Args:
|
||||
obj(Object): Instance of class.
|
||||
|
@ -216,13 +217,13 @@ def resolve_symbol(namespace, symbol):
|
|||
try:
|
||||
resolve_ = namespace[symbol]
|
||||
|
||||
# list and dict is not hashable ,it can not be key for the map, just return the result
|
||||
# The list and dict is not hashable, it can not be key for the map, just return the result
|
||||
if isinstance(resolve_, (tuple, list, dict)):
|
||||
return resolve_
|
||||
if getattr(resolve_, "__hash__") is None:
|
||||
return resolve_
|
||||
|
||||
# Raise a proper error if not using Fallback feature.
|
||||
# Raise a proper error if not using JIT Fallback feature.
|
||||
if support_fallback_ == '0':
|
||||
# Raise NotImplementedError when parsing the numpy methods, but not the numpy constant.
|
||||
if namespace.name == "numpy" and \
|
||||
|
@ -317,6 +318,8 @@ def get_obj_type(obj):
|
|||
obj_type = RESOLVE_TYPE_METHOD
|
||||
elif isinstance(obj, type):
|
||||
obj_type = RESOLVE_TYPE_CLASS_TYPE
|
||||
elif isinstance(obj, Namespace):
|
||||
obj_type = RESOLVE_TYPE_NAMESPACE_INSTANCE
|
||||
elif _is_class_instance(obj):
|
||||
obj_type = RESOLVE_TYPE_CLASS_INSTANCE
|
||||
elif _is_numpy_int_number(obj):
|
||||
|
@ -361,7 +364,7 @@ def _is_ms_class(obj):
|
|||
|
||||
def _is_class_instance(obj):
|
||||
"""Confirm the obj is class instance."""
|
||||
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_ms_class(obj)
|
||||
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_ms_class(obj) or hasattr(obj, '__parse_method__')
|
||||
|
||||
|
||||
def _is_numpy_int_number(obj):
|
||||
|
@ -504,9 +507,6 @@ def is_class_type(cls):
|
|||
|
||||
def get_ms_class_name(cls):
|
||||
"""Get the name of the class instance decorated by ms_class."""
|
||||
# Check if cls is nn.Cell.
|
||||
if isinstance(cls, nn.Cell):
|
||||
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
||||
if isinstance(cls, type):
|
||||
return cls.__name__
|
||||
return cls.__class__.__name__
|
||||
|
|
|
@ -268,14 +268,17 @@ class _MindsporeFunctionExecutor:
|
|||
generate_name = generate_name + ".grad"
|
||||
if is_pynative_parallel():
|
||||
generate_name = generate_name[:generate_name.rfind(str(id(self.fn)))] + str(id(self.shard_parent_obj))
|
||||
self.fn.__parse_method__ = method_name
|
||||
|
||||
# Add key with obj
|
||||
if self.obj is not None:
|
||||
if self.obj.__module__ != self.fn.__module__:
|
||||
logger.info(f'`obj` module not equal to `fn` module: {self.obj.__module__}, {self.fn.__module__}')
|
||||
self.obj.__parse_method__ = method_name
|
||||
generate_name = generate_name + '.' + str(self.obj.create_time) + '.' + str(id(self.obj))
|
||||
if isinstance(self.obj, ms.nn.Cell):
|
||||
generate_name = generate_name + '.' + str(self.obj.create_time)
|
||||
else:
|
||||
generate_name = generate_name + '.' + str(self._create_time)
|
||||
generate_name = generate_name + '.' + str(id(self.obj))
|
||||
else:
|
||||
# Different instance of same class may use same memory(means same obj_id) at diff times.
|
||||
# To avoid unexpected phase matched, add create_time to generate_name.
|
||||
|
@ -298,7 +301,8 @@ class _MindsporeFunctionExecutor:
|
|||
if self.obj is None:
|
||||
is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
|
||||
else:
|
||||
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
||||
if isinstance(self.obj, ms.nn.Cell):
|
||||
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
||||
is_compile = self._graph_executor.compile(self.obj, compile_args, phase, True)
|
||||
|
||||
if is_pynative_parallel() and self.fn.__name__ == _PYNATIVE_PARRALLEL_FUNC_NAME:
|
||||
|
@ -570,6 +574,9 @@ def ms_class(cls):
|
|||
# Check if cls is of type class.
|
||||
if not inspect.isclass(cls):
|
||||
raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
|
||||
# Check if cls is nn.Cell.
|
||||
if issubclass(cls, ms.nn.Cell):
|
||||
raise TypeError(f"Decorator ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
||||
logger.info(f'Found ms_class: {cls}.')
|
||||
setattr(cls, '__ms_class__', True)
|
||||
return cls
|
||||
|
|
|
@ -83,7 +83,7 @@ class Cell(Cell_):
|
|||
"""
|
||||
|
||||
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
|
||||
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
|
||||
'_construct_inputs_num', '_create_time', '_func_graph_flags', '_parallel_inputs_run',
|
||||
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase', '_auto_parallel_mode',
|
||||
'_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
|
||||
'_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
|
||||
|
@ -1473,9 +1473,9 @@ class Cell(Cell_):
|
|||
flags (dict): Network configuration information, currently it is used for the binding of network and
|
||||
dataset. Users can also customize network attributes by this parameter. Default: None.
|
||||
"""
|
||||
if not hasattr(self, "_mindspore_flags"):
|
||||
self._mindspore_flags = {}
|
||||
self._mindspore_flags.update({**flags})
|
||||
if not hasattr(self, "_func_graph_flags"):
|
||||
self._func_graph_flags = {}
|
||||
self._func_graph_flags.update({**flags})
|
||||
self.__dict__.update({**flags})
|
||||
self._add_mixed_precision_flag(**flags)
|
||||
return self
|
||||
|
@ -1502,9 +1502,9 @@ class Cell(Cell_):
|
|||
"""
|
||||
Get the self_defined attributes of the cell, which can be added by `add_flags` method.
|
||||
"""
|
||||
if not hasattr(self, "_mindspore_flags"):
|
||||
self._mindspore_flags = {}
|
||||
return self._mindspore_flags
|
||||
if not hasattr(self, "_func_graph_flags"):
|
||||
self._func_graph_flags = {}
|
||||
return self._func_graph_flags
|
||||
|
||||
def _set_mixed_precision_type_recursive(self, mixed_type):
|
||||
"""Set mixed precision type to each cell"""
|
||||
|
|
|
@ -50,15 +50,15 @@ def add_flags(fn=None, **flags):
|
|||
Examples:
|
||||
>>> net = Net();
|
||||
>>> net = add_flags(net, predit=True)
|
||||
>>> print(hasattr(net, '_mindspore_flags'))
|
||||
>>> print(hasattr(net, '_func_graph_flags'))
|
||||
True
|
||||
"""
|
||||
def deco(fn):
|
||||
# need set the attr and access on c++
|
||||
if not hasattr(fn, "_mindspore_flags"):
|
||||
fn._mindspore_flags = {}
|
||||
if not hasattr(fn, "_func_graph_flags"):
|
||||
fn._func_graph_flags = {}
|
||||
|
||||
fn._mindspore_flags.update({**flags})
|
||||
fn._func_graph_flags.update({**flags})
|
||||
return fn
|
||||
ret = deco
|
||||
if fn is not None:
|
||||
|
@ -84,13 +84,13 @@ def core(fn=None, **flags):
|
|||
Examples:
|
||||
>>> net = Net()
|
||||
>>> net = core(net, predit=True)
|
||||
>>> print(hasattr(net, '_mindspore_flags'))
|
||||
>>> print(hasattr(net, '_func_graph_flags'))
|
||||
True
|
||||
"""
|
||||
# need set the attr and access on c++
|
||||
|
||||
def deco(fn):
|
||||
fn._mindspore_flags = {
|
||||
fn._func_graph_flags = {
|
||||
'core': True,
|
||||
**flags,
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ class _DataWrapper(nn.Cell):
|
|||
super(_DataWrapper, self).__init__(
|
||||
auto_prefix=False, flags=network.get_flags())
|
||||
# Also copy the flag in `network` construct
|
||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||
flags = getattr(network.__class__.construct, "_func_graph_flags", {})
|
||||
self.info = (dataset_types, dataset_shapes)
|
||||
self.add_flags(**flags)
|
||||
self.get_next = P.GetNext(
|
||||
|
|
|
@ -440,12 +440,12 @@ def test_fallback_raise_error_decorate_cell():
|
|||
Description: Decorator ms_class cannot be used for nn.Cell
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
@ms_class
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
x = Tensor(1)
|
||||
net = Net()
|
||||
net(x)
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore as ms
|
||||
|
||||
|
||||
class Net:
|
||||
@ms.ms_function
|
||||
def test(self, x, y):
|
||||
return ms.ops.mul(x, y)
|
||||
|
||||
|
||||
def test_user_defined_class_with_ms_function():
|
||||
"""
|
||||
Feature: User defined class with ms_function.
|
||||
Description: Test user defined class method with ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = ms.Tensor([3])
|
||||
y = ms.Tensor([2])
|
||||
net = Net()
|
||||
net.test(x, y)
|
||||
|
||||
|
||||
@ms.ms_class
|
||||
class MsClassNet:
|
||||
@ms.ms_function
|
||||
def test(self, x, y):
|
||||
return ms.ops.mul(x, y)
|
||||
|
||||
|
||||
def test_ms_class_with_ms_function():
|
||||
"""
|
||||
Feature: ms_class with ms_function.
|
||||
Description: Test ms_class method with ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = ms.Tensor([3])
|
||||
y = ms.Tensor([2])
|
||||
net = MsClassNet()
|
||||
net.test(x, y)
|
||||
|
||||
|
||||
class CellNet(ms.nn.Cell):
|
||||
@ms.ms_function
|
||||
def test(self, x, y):
|
||||
return ms.ops.mul(x, y)
|
||||
|
||||
|
||||
def test_cell_with_ms_function():
|
||||
"""
|
||||
Feature: Cell with ms_function.
|
||||
Description: Test Cell method with ms_function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = ms.Tensor([3])
|
||||
y = ms.Tensor([2])
|
||||
net = CellNet()
|
||||
net.test(x, y)
|
|
@ -25,9 +25,6 @@ grad_all = C.GradOperation(get_all=True)
|
|||
|
||||
|
||||
class CellBprop(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CellBprop, self).__init__()
|
||||
|
||||
def construct(self, x, y):
|
||||
return 2 * x * x + y * y
|
||||
|
Loading…
Reference in New Issue