!30855 Support user-defined classes by ms_class decorators
Merge pull request !30855 from huangbingjian/ms_class_dev
This commit is contained in:
commit
3c0e34ada0
|
@ -23,60 +23,65 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
// {prim::kPrimGetAttr, object, attr}
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
|
||||
PatternNode<AnfNodePtr> getattr_operand, ns_node, sym_node, attr_node, bool_node;
|
||||
auto GetAttrResolveLambda = [&node, &getattr_operand, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto getattr_operand_node = getattr_operand.GetNode(node);
|
||||
auto attr = attr_node.GetNode(node);
|
||||
PatternNode<AnfNodePtr> object, attr, ns_node, sym_node;
|
||||
auto GetAttrLambda = [&node, &object, &attr, &optimizer]() -> AnfNodePtr {
|
||||
auto object_node = object.GetNode(node);
|
||||
auto attr_node = attr.GetNode(node);
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(getattr_operand_node);
|
||||
if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(object_node);
|
||||
auto module_name = name_space->module();
|
||||
constexpr std::string_view parse_super_name = "namespace";
|
||||
if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos &&
|
||||
symbol->symbol() != parse_super_name) {
|
||||
auto obj = parse::GetSymbolObject(name_space, symbol, node);
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), obj, getattr_operand_node, attr);
|
||||
auto symbol_obj = parse::GetSymbolObject(name_space, symbol, node);
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), symbol_obj, object_node, attr_node);
|
||||
}
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
auto operand_cnode = getattr_operand_node->cast<CNodePtr>();
|
||||
constexpr size_t getitem_inputs_size = 3;
|
||||
if (operand_cnode != nullptr && operand_cnode->size() == getitem_inputs_size) {
|
||||
constexpr auto prim_index = 0;
|
||||
if (parse::IsGetItemCNode(object_node)) {
|
||||
auto getitem_cnode = object_node->cast<CNodePtr>();
|
||||
constexpr auto resolve_index = 1;
|
||||
constexpr auto index_index = 2;
|
||||
auto prim_node = operand_cnode->input(prim_index);
|
||||
auto resolve_node = operand_cnode->input(resolve_index);
|
||||
auto index_node = operand_cnode->input(index_index);
|
||||
if (!parse::IsResolveNodeWithGetItem(prim_node) || !IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
return nullptr;
|
||||
auto resolve_node = getitem_cnode->input(resolve_index);
|
||||
auto index_node = getitem_cnode->input(index_index);
|
||||
if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) {
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
||||
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr_node, getitem_cnode);
|
||||
}
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr_node);
|
||||
}
|
||||
auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node);
|
||||
auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node);
|
||||
if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
|
||||
return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr, operand_cnode);
|
||||
}
|
||||
return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr);
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
if (IsValueNode<parse::NameSpace>(object_node)) {
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(object_node);
|
||||
auto attr_str = GetValue<std::string>(GetValueNode(attr_node));
|
||||
parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(attr_str);
|
||||
return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node);
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, MsClassObject, attr}
|
||||
if (IsValueNode<parse::MsClassObject>(object_node)) {
|
||||
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node);
|
||||
auto attr_str = GetValue<std::string>(GetValueNode(attr_node));
|
||||
return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
|
||||
}
|
||||
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
if (IsValueNode<BoolImm>(object_node)) {
|
||||
return object_node;
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr {
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto str = GetValue<std::string>(GetValueNode(attr_node.GetNode(node)));
|
||||
parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(str);
|
||||
auto manager = optimizer->manager();
|
||||
return parse::ResolveSymbol(manager, name_space, symbol, node);
|
||||
};
|
||||
|
||||
auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr {
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(ns_node.GetNode(node));
|
||||
auto symbol = GetValueNode<parse::SymbolPtr>(sym_node.GetNode(node));
|
||||
|
@ -84,18 +89,9 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
|||
return parse::ResolveSymbol(manager, name_space, symbol, node);
|
||||
};
|
||||
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, getattr_operand, attr_node), GetAttrResolveLambda,
|
||||
attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda,
|
||||
ns_node.CheckFunc(IsValueNode<parse::NameSpace>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
MATCH_REPLACE_IF(
|
||||
node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node,
|
||||
bool_node.CheckFunc(IsValueNode<BoolImm>, node) && attr_node.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimGetAttr, object, attr}
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, object, attr), GetAttrLambda,
|
||||
attr.CheckFunc(IsValueNode<StringImm>, node));
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
MATCH_REPLACE_LAMBDA_IF(
|
||||
node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda,
|
||||
|
|
|
@ -40,6 +40,7 @@ namespace irpass {
|
|||
// {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr}
|
||||
// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr}
|
||||
// {prim::kPrimGetAttr, namespace, attr}
|
||||
// {prim::kPrimGetAttr, MsClassObject, attr}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
// {prim::kPrimResolve, namespace, symbol}
|
||||
class Resolver : public OptimizerCaller {
|
||||
|
|
|
@ -253,6 +253,15 @@ ValuePtr ConvertDataClass(const py::object &obj) {
|
|||
return converted;
|
||||
}
|
||||
|
||||
ValuePtr ConvertMsClass(const py::object &obj) {
|
||||
MS_LOG(DEBUG) << "Converting ms class";
|
||||
// Convert class instance decorated with ms_class.
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj);
|
||||
auto cls_name = py::cast<std::string>(name);
|
||||
return std::make_shared<MsClassObject>(obj, cls_name);
|
||||
}
|
||||
|
||||
ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
|
||||
MS_LOG(DEBUG) << "Converting primitive object" << use_signature;
|
||||
|
||||
|
@ -502,6 +511,7 @@ static const std::vector<DataConverterPtr> &GetDataConverters() {
|
|||
std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
|
||||
std::make_shared<ByTypeDataConverter<py::module>>(ConvertModuleNameSpace),
|
||||
std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
|
||||
std::make_shared<ByAttrDataConverter>(PYTHON_MS_CLASS, ConvertMsClass),
|
||||
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
|
||||
std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
|
||||
std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
|
||||
|
|
|
@ -67,6 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
|
|||
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
|
||||
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
|
||||
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
|
||||
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
|
||||
const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr";
|
||||
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
||||
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
|
||||
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
|
||||
|
|
|
@ -307,7 +307,7 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
|
|||
AnfNodePtr resolved_node = nullptr;
|
||||
bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
|
||||
if (!success) {
|
||||
MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo.";
|
||||
MS_LOG(EXCEPTION) << "Parse Resolve covert failed.";
|
||||
}
|
||||
if (IsValueNode<FuncGraph>(resolved_node)) {
|
||||
auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
|
||||
|
@ -465,6 +465,40 @@ bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsGetItemCNode(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
constexpr size_t getitem_inputs_size = 3;
|
||||
if (cnode->size() != getitem_inputs_size) {
|
||||
return false;
|
||||
}
|
||||
constexpr auto prim_index = 0;
|
||||
return IsResolveNodeWithGetItem(cnode->input(prim_index));
|
||||
}
|
||||
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
|
||||
const std::string &attr, const AnfNodePtr &node) {
|
||||
// Get attribute or method from ms_class obj.
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
|
||||
py::object cls_obj = ms_class->obj();
|
||||
if (!py::hasattr(cls_obj, attr.c_str())) {
|
||||
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
|
||||
}
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr);
|
||||
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return res_node;
|
||||
}
|
||||
|
||||
namespace {
|
||||
opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
|
||||
// For resolve and getattr primitive.
|
||||
|
|
|
@ -131,6 +131,18 @@ class InterpretedObject final : public PyObjectWrapper {
|
|||
};
|
||||
using InterpretedObjectPtr = std::shared_ptr<InterpretedObject>;
|
||||
|
||||
class MsClassObject final : public PyObjectWrapper {
|
||||
public:
|
||||
explicit MsClassObject(const py::object &obj, const std::string &name = "ms class")
|
||||
: PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
|
||||
~MsClassObject() override = default;
|
||||
MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<External>());
|
||||
}
|
||||
};
|
||||
using MsClassObjectPtr = std::shared_ptr<MsClassObject>;
|
||||
|
||||
// ClassObject class wrappers dataclass
|
||||
class ClassObject final : public PyObjectWrapper {
|
||||
public:
|
||||
|
@ -168,8 +180,11 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
|
|||
AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
|
||||
const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
|
||||
const CNodePtr &operand_cnode);
|
||||
// Check if node is resolve node with getitem.
|
||||
bool IsResolveNodeWithGetItem(const AnfNodePtr &node);
|
||||
AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
|
||||
const std::string &attr, const AnfNodePtr &node);
|
||||
|
||||
// Check if node is cnode with getitem.
|
||||
bool IsGetItemCNode(const AnfNodePtr &node);
|
||||
|
||||
// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
|
||||
bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true);
|
||||
|
|
|
@ -19,5 +19,6 @@ namespace mindspore {
|
|||
const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__";
|
||||
const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__";
|
||||
const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
|
||||
const char PYTHON_MS_CLASS[] = "__ms_class__";
|
||||
const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__";
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,6 +22,7 @@ namespace mindspore {
|
|||
extern const char PYTHON_PRIMITIVE_FLAG[];
|
||||
extern const char PYTHON_CELL_AS_LIST[];
|
||||
extern const char PYTHON_DATACLASS_FIELDS[];
|
||||
extern const char PYTHON_MS_CLASS[];
|
||||
extern const char PYTHON_CLASS_MEMBER_NAMESPACE[];
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -220,6 +220,12 @@ static ValueNameToConverterVector value_name_to_converter = {
|
|||
auto class_type = value->cast<parse::ClassTypePtr>();
|
||||
return class_type->obj();
|
||||
}},
|
||||
// parse::MsClassObject
|
||||
{parse::MsClassObject::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto ms_class_object = value->cast<parse::MsClassObjectPtr>();
|
||||
return ms_class_object->obj();
|
||||
}},
|
||||
// parse::InterpretedObject
|
||||
{parse::InterpretedObject::kTypeId,
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
|
|
|
@ -23,7 +23,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
|
||||
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
|
||||
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
|
||||
get_object_description, get_class_attr_namespace_symbol)
|
||||
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
get_ms_class_attr)
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
|
||||
|
@ -32,4 +33,5 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
|
|||
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
|
||||
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
|
||||
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
|
||||
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol']
|
||||
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'get_ms_class_attr']
|
||||
|
|
|
@ -410,6 +410,30 @@ def get_dataclass_methods(cls):
|
|||
return methods
|
||||
|
||||
|
||||
def get_ms_class_name(cls):
|
||||
"""Get the name of the class instance decorated by ms_class."""
|
||||
# Check if cls is nn.Cell.
|
||||
if isinstance(cls, nn.Cell):
|
||||
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
||||
if isinstance(cls, type):
|
||||
name = cls.__name__
|
||||
else:
|
||||
name = cls.__class__.__name__
|
||||
# Get the name of cls.
|
||||
cls_name = cls.__module__ + '.' + name
|
||||
return cls_name
|
||||
|
||||
|
||||
def get_ms_class_attr(cls, name: str):
|
||||
"""Get attribute or method of ms_class obj."""
|
||||
# Don't take into account python magic methods and private variables.
|
||||
if name.startswith('_'):
|
||||
raise AttributeError(f"{name} is a private variable or magic method, which is not supported.")
|
||||
if not hasattr(cls, name):
|
||||
raise AttributeError(f"{cls} has no attribute: {name}.")
|
||||
return getattr(cls, name)
|
||||
|
||||
|
||||
def convert_to_ms_tensor(data):
|
||||
"""Convert C++ tensor to mindspore tensor."""
|
||||
return Tensor(data)
|
||||
|
@ -562,8 +586,8 @@ def eval_script(exp_str, params):
|
|||
local_params = _convert_data(local_params)
|
||||
obj = eval(exp_str, global_params, local_params)
|
||||
except Exception as e:
|
||||
error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \
|
||||
". You can try to turn off the Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'."
|
||||
error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \
|
||||
". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'."
|
||||
logger.error(error_info)
|
||||
raise e
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Top-level reference to dtype of common module."""
|
||||
from . import dtype
|
||||
from .api import ms_function, ms_memory_recycle, _convert_data
|
||||
from .api import ms_function, ms_memory_recycle, ms_class, _convert_data
|
||||
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
|
||||
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
|
||||
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
|
||||
|
@ -54,7 +54,7 @@ __all__ = [
|
|||
|
||||
__all__.extend([
|
||||
"Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor
|
||||
'ms_function', # api
|
||||
'ms_function', 'ms_class', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype", "_convert_data",
|
||||
"set_seed", "get_seed", # random seed
|
||||
|
|
|
@ -20,6 +20,7 @@ import sys
|
|||
import os
|
||||
import time
|
||||
import ast
|
||||
import inspect
|
||||
import importlib
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
|
@ -439,12 +440,64 @@ def ms_function(fn=None, obj=None, input_signature=None):
|
|||
return wrap_mindspore(fn)
|
||||
return wrap_mindspore
|
||||
|
||||
|
||||
def ms_class(cls):
|
||||
"""
|
||||
Class decorator for user-defined classes.
|
||||
|
||||
This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods.
|
||||
|
||||
Args:
|
||||
cls (Class): User-defined class.
|
||||
|
||||
Returns:
|
||||
Class with __ms_class__ attribute.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import ms_class
|
||||
...
|
||||
>>> @ms_class
|
||||
>>> class UserDefinedNet:
|
||||
... def __init__(self):
|
||||
... self.value = 10
|
||||
...
|
||||
... def func(self, x):
|
||||
... return 2 * x
|
||||
...
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.net = UserDefinedNet()
|
||||
...
|
||||
... def construct(self, x):
|
||||
... out = self.net.value + self.net.func(x)
|
||||
... return out
|
||||
...
|
||||
>>> net = Net()
|
||||
>>> out = net(5)
|
||||
>>> print(out)
|
||||
20
|
||||
"""
|
||||
|
||||
# Check if cls is of type class.
|
||||
if not inspect.isclass(cls):
|
||||
raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.')
|
||||
logger.info(f'Found ms_class: {cls}.')
|
||||
setattr(cls, '__ms_class__', True)
|
||||
return cls
|
||||
|
||||
|
||||
def is_pynative_parallel():
|
||||
run_mode = context.get_context('mode')
|
||||
parallel_mode = context.get_auto_parallel_context('parallel_mode')
|
||||
return run_mode == context.PYNATIVE_MODE and parallel_mode in (
|
||||
context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL)
|
||||
|
||||
|
||||
def _get_auto_split_param_names(parameter_layout_dict):
|
||||
auto_split_param_names = []
|
||||
for key, value in parameter_layout_dict.items():
|
||||
|
@ -899,4 +952,4 @@ def ms_memory_recycle():
|
|||
_cell_graph_executor = _CellGraphExecutor()
|
||||
_pynative_executor = _PynativeExecutor()
|
||||
|
||||
__all__ = ['ms_function', 'ms_memory_recycle']
|
||||
__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class']
|
||||
|
|
|
@ -243,117 +243,6 @@ def test_scipy_module():
|
|||
print(out)
|
||||
|
||||
|
||||
def test_self_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.dim = 1
|
||||
|
||||
def construct(self, x):
|
||||
batch = x.shape[0]
|
||||
one = Tensor(np.ones([batch, self.dim]), mstype.float16)
|
||||
return one * x
|
||||
|
||||
net = Network()
|
||||
x = Tensor([1, 2], mstype.float32)
|
||||
out = net(x)
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_attr_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self, fn):
|
||||
super(Network, self).__init__()
|
||||
self.fn = fn
|
||||
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
net = Network(fn)
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_attr_3():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.value = [2, 2, 3]
|
||||
|
||||
def construct(self):
|
||||
x = np.array(self.value.count(2))
|
||||
return Tensor(x)
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_self_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_self_method_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
z = self.fn(x, y)
|
||||
out = Tensor(z)
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_probability_cauchy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
|
|
@ -0,0 +1,398 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context, ms_class
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_self_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.dim = 1
|
||||
|
||||
def construct(self, x):
|
||||
batch = x.shape[0]
|
||||
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
|
||||
return one * x
|
||||
|
||||
net = Network()
|
||||
x = Tensor([1, 2], mstype.float32)
|
||||
out = net(x)
|
||||
expect = np.array([[1., 2.], [1., 2.]])
|
||||
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
|
||||
|
||||
|
||||
def test_fallback_self_attr_fn():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self, fn):
|
||||
super(Network, self).__init__()
|
||||
self.fn = fn
|
||||
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
net = Network(fn)
|
||||
out = net()
|
||||
expect = np.array([4, 6, 8])
|
||||
assert np.all(out.asnumpy() == expect)
|
||||
|
||||
|
||||
def test_fallback_self_attr_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.value = [2, 2, 3]
|
||||
|
||||
def construct(self):
|
||||
x = np.array(self.value.count(2))
|
||||
return Tensor(x)
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_fallback_self_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
expect = np.array([4, 6, 8])
|
||||
assert np.all(out.asnumpy() == expect)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_fallback_self_method_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
z = self.fn(x, y)
|
||||
out = Tensor(z)
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_fallback_class_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class attributes in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.number
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_fallback_class_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class methods in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.val = 2
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.act(1, 2)
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_class_input_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class attributes in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = Tensor(np.array([1, 2, 3]))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = net()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.number
|
||||
return out
|
||||
|
||||
net = Net(InnerNet)
|
||||
out = net()
|
||||
expect_res = np.array([1, 2, 3])
|
||||
assert np.all(out.asnumpy() == expect_res)
|
||||
|
||||
|
||||
def test_fallback_class_input_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class methods in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.val = 2
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = net()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.act(1, 2)
|
||||
return out
|
||||
|
||||
net = Net(InnerNet)
|
||||
out = net()
|
||||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_class_class_nested():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test nested ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Inner:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.inner = Inner()
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.inner.number
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_fallback_class_cell_nested():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test nested ms_class and cell in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, val):
|
||||
super().__init__()
|
||||
self.val = val
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.val
|
||||
|
||||
@ms_class
|
||||
class TrainNet():
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, x):
|
||||
out = self.net(x)
|
||||
return out * 2
|
||||
|
||||
def __init__(self, net):
|
||||
self.net = net
|
||||
loss_net = self.Loss(self.net)
|
||||
self.number = loss_net(10)
|
||||
|
||||
global_net = Net(1)
|
||||
class LearnNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.value = TrainNet(global_net).number
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.value
|
||||
|
||||
leanrn_net = LearnNet()
|
||||
out = leanrn_net(3)
|
||||
print(out)
|
||||
assert out == 25
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph yet')
|
||||
def test_fallback_class_isinstance():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self, x):
|
||||
if isinstance(self.inner_net, InnerNet):
|
||||
return x + 10
|
||||
return x
|
||||
|
||||
net = Net()
|
||||
out = net(5)
|
||||
assert out == 15
|
||||
|
||||
|
||||
def test_fallback_raise_error_not_class_type():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
with pytest.raises(TypeError):
|
||||
@ms_class
|
||||
def func(x, y):
|
||||
return x + y
|
||||
|
||||
func(1, 2)
|
||||
|
||||
|
||||
def test_fallback_raise_error_not_class_instance():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
out = InnerNet().number
|
||||
return out
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_fallback_raise_error_decorate_cell():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
x = Tensor(1)
|
||||
net = Net()
|
||||
net(x)
|
Loading…
Reference in New Issue