!32613 Supports creating and calling instances of ms_class

Merge pull request !32613 from huangbingjian/class_dev
This commit is contained in:
i-robot 2022-04-16 06:14:29 +00:00 committed by Gitee
commit 4c3faa7f8f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 837 additions and 520 deletions

View File

@ -658,13 +658,19 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
return class_type;
}
// Check the object is Cell Instance.
// Check if the object is Cell instance.
bool IsCellInstance(const py::object &obj) {
auto class_type = GetClassInstanceType(obj);
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
return is_cell;
}
// Check if the object is class type.
bool IsClassType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast<bool>();
}
// Create the python class instance.
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
bool IsCellInstance(const py::object &obj);
bool IsClassType(const py::object &obj);
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);

View File

@ -67,8 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type";
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr";
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";

View File

@ -67,6 +67,7 @@ struct AnfDumpHandlerRegister {
}
} callback_register;
} // namespace
abstract::AbstractBasePtr ClassObject::ToAbstract() {
ClassPtr cls_ptr = ParseDataClass(obj());
auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
@ -78,6 +79,24 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() {
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
}
abstract::AbstractBasePtr MsClassObject::ToAbstract() {
auto abs_scalar =
std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<MsClassType>());
AbstractBasePtrList args_spec_list = {abs_scalar};
abstract::PrimitiveAbstractClosurePtr func_ptr = nullptr;
bool is_class_type = parse::data_converter::IsClassType(obj());
if (is_class_type) {
// Class type as func, such as Net(x, y)
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
} else {
// Class instance as func, such as net(x, y)
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCallInstance);
}
auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
ret_val->set_value_desc(ToString());
return ret_val;
}
static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
@ -520,14 +539,15 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsCl
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
constexpr size_t prefix_index = 0;
if (attr.size() > 0 && attr[prefix_index] == '_') {
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
}
py::object cls_obj = ms_class->obj();
if (!py::hasattr(cls_obj, attr.c_str())) {
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
}
const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR;
const std::string module = "mindspore._extends.parse.parser";
py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr);
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
TraceManager::ClearParseOrResolveDebugInfo();
return res_node;

View File

@ -116,6 +116,7 @@ class PyObjectWrapper : public Named {
// the object that needs to be resolved
py::object obj_;
};
using PyObjectWrapperPtr = std::shared_ptr<PyObjectWrapper>;
// InterpretedObject class wrappers interpreted python object.
class InterpretedObject final : public PyObjectWrapper {
@ -137,9 +138,7 @@ class MsClassObject final : public PyObjectWrapper {
: PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
~MsClassObject() override = default;
MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<External>());
}
abstract::AbstractBasePtr ToAbstract() override;
};
using MsClassObjectPtr = std::shared_ptr<MsClassObject>;

View File

@ -1304,6 +1304,35 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
return StaticGetterInferred(converted_value, data_conf, out_conf);
}
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const ValuePtr &data_value, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(item_value);
MS_EXCEPTION_IF_NULL(data_value);
// Get the name of item.
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
}
std::string item_name = item_value->cast<StringImmPtr>()->value();
// Get ms_class object.
if (!data_value->isa<parse::MsClassObject>()) {
MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString();
}
auto ms_class = data_value->cast<parse::MsClassObjectPtr>();
MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << ".";
// Get the attr/method of ms_class object.
auto out_node = out_conf->node();
FuncGraphPtr func_graph = out_node->func_graph();
auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node);
// Replace old node with the resolved new node in order list.
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const TypePtr &data_type, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
@ -1363,17 +1392,45 @@ int64_t GetResolveType(const TypePtr &data_type) {
return kResolveTypeFunction;
}
ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
if (!abs->isa<abstract::PartialAbstractClosure>()) {
return nullptr;
}
auto partial_abs = abs->cast<abstract::PartialAbstractClosurePtr>();
auto fn = partial_abs->fn();
if (!fn->isa<abstract::PrimitiveAbstractClosure>()) {
return nullptr;
}
// Check if type is kObjectTypeClass.
auto args = partial_abs->args();
if (args.size() > 0) {
constexpr size_t first_input_index = 0;
auto first_arg = args[first_input_index];
MS_EXCEPTION_IF_NULL(first_arg);
auto type = first_arg->BuildType();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() == kObjectTypeClass) {
return first_arg->BuildValue();
}
}
return nullptr;
}
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function
CheckArgsSize("StaticGetter", args_spec_list, 2);
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
TypePtr data_type = args_spec_list[0]->BuildType();
ValuePtr item_value = args_spec_list[1]->BuildValue();
constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
auto data_args = args_spec_list[data_index];
auto item_args = args_spec_list[item_index];
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString();
TypePtr data_type = data_args->BuildType();
ValuePtr item_value = item_args->BuildValue();
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
scope = out_conf->node()->scope();
@ -1384,6 +1441,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
}
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
}
int64_t resolve_type = GetResolveType(data_type);
if (resolve_type == kResolveTypeUserDefineClass) {
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
@ -1581,46 +1642,47 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
// Check the type parameter.
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
}
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
TypePtr type = args_spec_list[0]->GetTypeTrack();
constexpr size_t type_index = 0;
auto arg_class_type = args_spec_list[type_index];
MS_EXCEPTION_IF_NULL(arg_class_type);
TypePtr type = arg_class_type->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kMetaTypeTypeType) {
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
<< type->ToString();
if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) {
MS_LOG(EXCEPTION)
<< "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got "
<< type->ToString();
}
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
ValuePtr value_track = arg_class_type->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
parse::PyObjectWrapperPtr type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
if (type_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
}
if (!type_obj->isa<parse::ClassType>()) {
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
<< type_obj->ToString() << ".";
if (!type_obj->isa<parse::ClassType>() && !type_obj->isa<parse::MsClassObject>()) {
MS_LOG(EXCEPTION)
<< "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got "
<< type_obj->ToString() << ".";
}
auto class_type = type_obj->obj();
MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
py::tuple params = GetParameters(args_spec_list);
// Create class instance.
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
<< "` failed, only support to create \'Cell\' or \'Primitive\' object.";
<< "` failed, only support to create \'Cell\', \'Primitive\' or "
<< "user-defined Class decorated with \'ms_class\'.";
}
// Process the object.
MS_EXCEPTION_IF_NULL(out_conf->node());
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(obj, &converted_ret, true);
@ -1628,7 +1690,6 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_LOG(EXCEPTION) << "Convert the python object failed";
}
MS_EXCEPTION_IF_NULL(converted_ret);
if (converted_ret->isa<FuncGraph>()) {
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
}
@ -1664,6 +1725,63 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
}
};
class CallInstanceEvaluator : public TransitionPrimEvaluator {
public:
CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {}
~CallInstanceEvaluator() override = default;
MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list should not be empty.";
}
constexpr size_t cls_index = 0;
auto arg_cls = args_spec_list[cls_index];
MS_EXCEPTION_IF_NULL(arg_cls);
TypePtr type = arg_cls->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kObjectTypeClass) {
MS_LOG(EXCEPTION) << "CallInstanceEvaluator require first parameter should be an object of TypeClass, but got "
<< type->ToString();
}
ValuePtr value_track = arg_cls->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
parse::MsClassObjectPtr ms_class = dyn_cast<parse::MsClassObject>(value_track);
if (ms_class == nullptr) {
MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject.";
}
// Call class instance, net(x, y) -> net.__call__(x, y)
py::object cls_obj = ms_class->obj();
const std::string call_func = "__call__";
if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
MS_LOG(EXCEPTION) << ms_class->name() << " has no " << call_func << " function, please check the code.";
}
py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
if (call_func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Parse python object " << call_func << " failed.";
}
FuncGraphManagerPtr manager = engine->func_graph_manager();
manager->AddFuncGraph(call_func_graph);
// Replace net with net.__call__
AnfNodePtr old_node = out_conf->node();
MS_EXCEPTION_IF_NULL(old_node);
CNodePtr old_cnode = dyn_cast<CNode>(old_node);
MS_EXCEPTION_IF_NULL(old_cnode);
std::vector<AnfNodePtr> inputs = {NewValueNode(call_func_graph)};
for (size_t i = 1; i < old_cnode->size(); i++) {
(void)inputs.emplace_back(old_cnode->input(i));
}
FuncGraphPtr func_graph = out_conf->func_graph();
auto new_cnode = func_graph->NewCNode(inputs);
// Continue to eval new_cnode.
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
return engine->ForwardConfig(out_conf, fn_conf);
}
};
class PyInterpretEvaluator : public TransitionPrimEvaluator {
public:
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
@ -2085,6 +2203,7 @@ void InitPrimEvaluatorConstructors() {
constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
constructor[prim::kPrimCallInstance] = std::make_shared<CallInstanceEvaluator>();
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();

View File

@ -82,7 +82,7 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
if (abstract->isa<AbstractScalar>()) {
TypePtr type = abstract->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->isa<EnvType>()) {
if (type->isa<EnvType>() || type->isa<MsClassType>()) {
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
}
if (type->isa<Problem>() || type->isa<External>()) {

View File

@ -345,6 +345,24 @@ class MS_CORE_API Problem final : public Type {
};
using ProblemPtr = std::shared_ptr<Problem>;
/// \brief MsClassType defines a type which is ms_class.
class MS_CORE_API MsClassType final : public Type {
public:
/// \brief The constructor of External.
///
/// \return The instance of External.
MsClassType() : Type(kObjectTypeClass) {}
/// \brief The destructor of External.
~MsClassType() override = default;
MS_DECLARE_PARENT(MsClassType, Type)
TypeId generic_type_id() const override { return kObjectTypeClass; }
TypePtr DeepCopy() const override { return std::make_shared<MsClassType>(); }
std::string DumpText() const override { return "MsClassType"; }
};
using MsClassTypePtr = std::shared_ptr<MsClassType>;
/// \brief External defines a type which is external.
class MS_CORE_API External final : public Type {
public:
@ -360,9 +378,6 @@ class MS_CORE_API External final : public Type {
TypeId generic_type_id() const override { return kMetaTypeExternal; }
TypePtr DeepCopy() const override { return std::make_shared<External>(); }
std::string DumpText() const override { return "ExternalType"; }
private:
TypePtr kind;
};
using ExternalPtr = std::shared_ptr<External>;

View File

@ -939,6 +939,7 @@ GVAR_DEF(PrimitivePtr, kPrimResolve, std::make_shared<Primitive>("resolve"));
GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared<Primitive>("embed"));
GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared<Primitive>("RefToEmbed"));
GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared<Primitive>("create_instance"));
GVAR_DEF(PrimitivePtr, kPrimCallInstance, std::make_shared<Primitive>("call_instance"));
// Other miscellaneous
GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared<Primitive>("get_ref_origin"));

View File

@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
#else
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
#endif
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
prim::kPrimMicroStepAllGather};

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,20 +18,18 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
get_ms_class_attr)
create_slice_obj, get_obj_id, get_module_namespace, get_obj_type, get_object_key,
get_ast_type, get_node_type, get_args, get_args_default_values, get_ast_namespace_symbol,
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, get_dataclass_attributes, get_dataclass_methods)
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',
'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type',
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'get_ms_class_attr']
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
'create_slice_obj', 'get_obj_id', 'get_module_namespace', 'get_obj_type', 'get_object_key',
'get_ast_type', 'get_node_type', 'get_args', 'get_args_default_values', 'get_ast_namespace_symbol',
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods']

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -118,7 +118,11 @@ class ClassMemberNamespace(Namespace):
except ValueError:
raise UnboundLocalError(name)
except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
# Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown.
cls = d.__class__
if hasattr(cls, '__ms_class__'):
raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.")
logger.info(f"'{cls.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)
@ -142,5 +146,4 @@ class ClassAttrNamespace(Namespace):
except ValueError:
raise UnboundLocalError(name)
except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)
raise AttributeError(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}'.")

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@ import hashlib
import inspect
import types
import importlib
from dataclasses import is_dataclass
from textwrap import dedent
import asttokens
@ -324,24 +323,26 @@ def get_class_instance_type(obj):
"""Get the class instance detail type."""
# check the obj type
logger.debug("Get the class type(%r)", obj)
class_type = CLASS_INSTANCE_TYPE_INVALID
if _is_class_instance(obj):
if isinstance(obj, nn.Cell):
class_type = CLASS_INSTANCE_TYPE_CELL
elif isinstance(obj, ops.Primitive):
class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
# Add the other type base requirement
return class_type
if isinstance(obj, nn.Cell):
return CLASS_INSTANCE_TYPE_CELL
if isinstance(obj, ops.Primitive):
return CLASS_INSTANCE_TYPE_PRIMITIVE
return CLASS_INSTANCE_TYPE_INVALID
def _is_class_instance(obj):
"""Confirm the obj is class instance."""
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
def _is_ms_class(obj):
"""Check if obj is ms_class object."""
return hasattr(obj, '__ms_class__')
def _is_dataclass_instance(obj):
"""Check whether a class is an instance of a dataclass (and not a dataclass itself)"""
return is_dataclass(obj) and not isinstance(obj, type)
return hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type)
def _is_class_instance(obj):
"""Confirm the obj is class instance."""
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) or _is_ms_class(obj)
def _convert_tuple_to_args_kwargs(params):
@ -358,7 +359,7 @@ def _convert_tuple_to_args_kwargs(params):
def is_supported_create_instance_type(cls_type):
"""Check if cls_type is a supported instance type."""
return issubclass(cls_type, (nn.Cell, ops.Primitive))
return issubclass(cls_type, (nn.Cell, ops.Primitive)) or _is_ms_class(cls_type)
def create_instance(cls_type, params=None):
@ -440,28 +441,19 @@ def get_dataclass_methods(cls):
return methods
def is_class_type(cls):
"""Check if cls is a class type."""
return isinstance(cls, type)
def get_ms_class_name(cls):
"""Get the name of the class instance decorated by ms_class."""
# Check if cls is nn.Cell.
if isinstance(cls, nn.Cell):
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
if isinstance(cls, type):
name = cls.__name__
else:
name = cls.__class__.__name__
# Get the name of cls.
cls_name = cls.__module__ + '.' + name
return cls_name
def get_ms_class_attr(cls, name: str):
"""Get attribute or method of ms_class obj."""
# Don't take into account python magic methods and private variables.
if name.startswith('_'):
raise AttributeError(f"{name} is a private variable or magic method, which is not supported.")
if not hasattr(cls, name):
raise AttributeError(f"{cls} has no attribute: {name}.")
return getattr(cls, name)
return cls.__name__
return cls.__class__.__name__
def convert_to_ms_tensor(data):

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,10 +15,7 @@
# limitations under the License.
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
from mindspore import Tensor, Parameter, CSRTensor, COOTensor, ms_class
from mindspore import dtype as mstype
from ..._checkparam import Validator as validator
@ -1828,16 +1825,16 @@ def float_floordiv(x, y):
#############
@dataclass(frozen=True)
@ms_class
class SequenceIterator:
"""
SequenceIterator is a util dataclass for iterating sequence object.
SequenceIterator is a util class for iterating sequence object.
Iterator to use for sequences like List, Array.
"""
idx: int
seq: list
def __init__(self, idx, seq):
self.idx = idx
self.seq = seq
@core(ignore_values=True)
def __ms_hasnext__(self):

View File

@ -0,0 +1,451 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_class
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = Tensor(1, dtype=mstype.int32)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.number
return out
net = Net()
out = net()
assert out.asnumpy() == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_method():
"""
Feature: JIT Fallback
Description: Access the methods of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = Tensor(2, dtype=mstype.int32)
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self, x, y):
out = self.inner_net.act(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
net = Net()
out = net(x, y)
assert out.asnumpy() == 10
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_call():
"""
Feature: JIT Fallback
Description: Call the __call__ function of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, val):
self.val = val
def __call__(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self, val):
super(Net, self).__init__()
self.inner_net = InnerNet(val)
def construct(self, x, y):
out = self.inner_net(x, y)
return out
val = Tensor(2, dtype=mstype.int32)
x = Tensor(3, dtype=mstype.int32)
y = Tensor(4, dtype=mstype.int32)
net = Net(val)
out = net(x, y)
assert out.asnumpy() == 14
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_input_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = Tensor(np.array([1, 2, 3]))
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.number
return out
net = Net(InnerNet)
out = net()
expect_res = np.array([1, 2, 3])
assert np.all(out.asnumpy() == expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_input_method():
"""
Feature: JIT Fallback
Description: Access the methods of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = Tensor(2, dtype=mstype.int32)
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.act(1, 2)
return out
net = Net(InnerNet)
out = net()
assert out.asnumpy() == 6
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_class_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class in graph.
Expectation: No exception.
"""
@ms_class
class Inner:
def __init__(self):
self.number = Tensor(1, dtype=mstype.int32)
@ms_class
class InnerNet:
def __init__(self):
self.inner = Inner()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.inner.number
return out
net = Net()
out = net()
assert out.asnumpy() == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_cell_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class and cell in graph.
Expectation: No exception.
"""
class Net(nn.Cell):
def __init__(self, val):
super().__init__()
self.val = val
def construct(self, x):
return x + self.val
@ms_class
class TrainNet():
class Loss(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = self.net(x)
return out * 2
def __init__(self, net):
self.net = net
loss_net = self.Loss(self.net)
self.number = loss_net(10)
global_net = Net(1)
class LearnNet(nn.Cell):
def __init__(self):
super().__init__()
self.value = TrainNet(global_net).number
def construct(self, x):
return x + self.value
leanrn_net = LearnNet()
out = leanrn_net(3)
print(out)
assert out == 25
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_type_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of class type.
Expectation: No exception.
"""
@ms_class
class InnerNet:
val = Tensor(2, dtype=mstype.int32)
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
# Support accessing attributes of class type, but do not support
# accessing methods, e.g. self.inner_net.act(1, 2)
def construct(self):
out = self.inner_net.val
return out
net = Net()
out = net()
assert out == 2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_create_instance_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of the created class instance.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, val):
self.number = val + 3
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
def construct(self, x):
net = self.inner_net(x)
return net.number
net = Net()
out = net(2)
assert out == 5
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_create_instance_method():
"""
Feature: JIT Fallback
Description: Access the methods of the created class instance.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, val):
self.number = val
def act(self, x, y):
return self.number * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
def construct(self, x, y, z):
net = self.inner_net(x)
return net.act(y, z)
x = 2
y = Tensor(2, dtype=mstype.int32)
z = Tensor(3, dtype=mstype.int32)
net = Net()
out = net(x, y, z)
assert out.asnumpy() == 10
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_create_instance_call():
"""
Feature: JIT Fallback
Description: Call the __call__ function of the created class instance.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, number):
self.number = number
def __call__(self, x, y):
return self.number * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
def construct(self, x, y, z):
net = self.inner_net(x)
out = net(y, z)
return out
x = 2
y = Tensor(2, dtype=mstype.int32)
z = Tensor(3, dtype=mstype.int32)
net = Net()
out = net(x, y, z)
assert out == 10
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_raise_error_not_class_type():
"""
Feature: JIT Fallback
Description: Decorator ms_class cannot be used for non-class types.
Expectation: No exception.
"""
with pytest.raises(TypeError):
@ms_class
def func(x, y):
return x + y
func(1, 2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_raise_error_decorate_cell():
"""
Feature: JIT Fallback
Description: Decorator ms_class cannot be used for nn.Cell
Expectation: No exception.
"""
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
with pytest.raises(TypeError):
x = Tensor(1)
net = Net()
net(x)

View File

@ -1,416 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_class, ms_function
from . import test_graph_fallback
context.set_context(mode=context.GRAPH_MODE)
def test_fallback_self_attr():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.dim = 1
def construct(self, x):
batch = x.shape[0]
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
return one * x
net = Network()
x = Tensor([1, 2], mstype.float32)
out = net(x)
expect = np.array([[1., 2.], [1., 2.]])
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
def test_fallback_self_attr_fn():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self, fn):
super(Network, self).__init__()
self.fn = fn
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(x, y):
return x + y
net = Network(fn)
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
def test_fallback_self_attr_attr():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.value = [2, 2, 3]
def construct(self):
x = np.array(self.value.count(2))
return Tensor(x)
net = Network()
out = net()
assert out == 2
def test_fallback_self_method():
"""
Feature: JIT Fallback
Description: Test self.method in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_fallback_self_method_tensor():
"""
Feature: JIT Fallback
Description: Test self.method in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
z = self.fn(x, y)
out = Tensor(z)
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
print(out)
def test_fallback_import_modules():
"""
Feature: JIT Fallback
Description: add_func is defined in test_graph_fallback.py
Expectation: No exception.
"""
@ms_function
def use_imported_module(x, y):
out = test_graph_fallback.add_func(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
out = use_imported_module(x, y)
print(out)
def test_fallback_class_attr():
"""
Feature: JIT Fallback
Description: Test user-defined class attributes in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.number
return out
net = Net()
out = net()
assert out == 1
def test_fallback_class_method():
"""
Feature: JIT Fallback
Description: Test user-defined class methods in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = 2
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.act(1, 2)
return out
net = Net()
out = net()
assert out == 6
def test_fallback_class_input_attr():
"""
Feature: JIT Fallback
Description: Test user-defined class attributes in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = Tensor(np.array([1, 2, 3]))
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.number
return out
net = Net(InnerNet)
out = net()
expect_res = np.array([1, 2, 3])
assert np.all(out.asnumpy() == expect_res)
def test_fallback_class_input_method():
"""
Feature: JIT Fallback
Description: Test user-defined class methods in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = 2
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.act(1, 2)
return out
net = Net(InnerNet)
out = net()
assert out == 6
def test_fallback_class_class_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class in graph.
Expectation: No exception.
"""
@ms_class
class Inner:
def __init__(self):
self.number = 1
@ms_class
class InnerNet:
def __init__(self):
self.inner = Inner()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.inner.number
return out
net = Net()
out = net()
assert out == 1
def test_fallback_class_cell_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class and cell in graph.
Expectation: No exception.
"""
class Net(nn.Cell):
def __init__(self, val):
super().__init__()
self.val = val
def construct(self, x):
return x + self.val
@ms_class
class TrainNet():
class Loss(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = self.net(x)
return out * 2
def __init__(self, net):
self.net = net
loss_net = self.Loss(self.net)
self.number = loss_net(10)
global_net = Net(1)
class LearnNet(nn.Cell):
def __init__(self):
super().__init__()
self.value = TrainNet(global_net).number
def construct(self, x):
return x + self.value
leanrn_net = LearnNet()
out = leanrn_net(3)
print(out)
assert out == 25
@pytest.mark.skip(reason='Not support in graph yet')
def test_fallback_class_isinstance():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self, x):
if isinstance(self.inner_net, InnerNet):
return x + 10
return x
net = Net()
out = net(5)
assert out == 15
def test_fallback_raise_error_not_class_type():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
with pytest.raises(TypeError):
@ms_class
def func(x, y):
return x + y
func(1, 2)
def test_fallback_raise_error_not_class_instance():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def construct(self):
out = InnerNet().number
return out
with pytest.raises(ValueError):
net = Net()
net()
def test_fallback_raise_error_decorate_cell():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
with pytest.raises(TypeError):
x = Tensor(1)
net = Net()
net(x)

View File

@ -0,0 +1,131 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_function
from . import test_graph_fallback
context.set_context(mode=context.GRAPH_MODE)
def test_fallback_self_attr():
"""
Feature: JIT Fallback
Description: Use self.attr in expressions supported by JIT Fallback.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.dim = 1
def construct(self, x):
batch = x.shape[0]
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
return one * x
net = Network()
x = Tensor([1, 2], mstype.float32)
out = net(x)
expect = np.array([[1., 2.], [1., 2.]])
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
def test_fallback_self_attr_fn():
"""
Feature: JIT Fallback
Description: Use self.attr of type function in expressions supported by JIT Fallback.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self, fn):
super(Network, self).__init__()
self.fn = fn
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(x, y):
return x + y
net = Network(fn)
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
def test_fallback_self_attr_attr():
"""
Feature: JIT Fallback
Description: In expressions supported by JIT Fallback, use the attribute of self.attr.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.value = [2, 2, 3]
def construct(self):
x = np.array(self.value.count(2))
return Tensor(x)
net = Network()
out = net()
assert out == 2
def test_fallback_self_method():
"""
Feature: JIT Fallback
Description: Use self.method in expressions supported by JIT Fallback.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
def test_fallback_import_modules():
"""
Feature: JIT Fallback
Description: Check whether the call to the third-party library is correct. It has nothing to do with class.
Expectation: No exception.
"""
@ms_function
def use_imported_module(x, y):
out = test_graph_fallback.add_func(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
out = use_imported_module(x, y)
print(out)