!32613 Supports creating and calling instances of ms_class

Merge pull request !32613 from huangbingjian/class_dev
This commit is contained in:
i-robot 2022-04-16 06:14:29 +00:00 committed by Gitee
commit 4c3faa7f8f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 837 additions and 520 deletions

View File

@ -658,13 +658,19 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
return class_type; return class_type;
} }
// Check the object is Cell Instance. // Check if the object is Cell instance.
bool IsCellInstance(const py::object &obj) { bool IsCellInstance(const py::object &obj) {
auto class_type = GetClassInstanceType(obj); auto class_type = GetClassInstanceType(obj);
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL); bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
return is_cell; return is_cell;
} }
// Check if the object is class type.
bool IsClassType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast<bool>();
}
// Create the python class instance. // Create the python class instance.
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) { py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);

View File

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019-2021 Huawei Technologies Co., Ltd * Copyright 2019-2022 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
bool IsCellInstance(const py::object &obj); bool IsCellInstance(const py::object &obj);
bool IsClassType(const py::object &obj);
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs); py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs); py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);

View File

@ -67,8 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type"; const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods"; const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type";
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name"; const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr";
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol"; const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";

View File

@ -67,6 +67,7 @@ struct AnfDumpHandlerRegister {
} }
} callback_register; } callback_register;
} // namespace } // namespace
abstract::AbstractBasePtr ClassObject::ToAbstract() { abstract::AbstractBasePtr ClassObject::ToAbstract() {
ClassPtr cls_ptr = ParseDataClass(obj()); ClassPtr cls_ptr = ParseDataClass(obj());
auto abs_scalar = std::make_shared<abstract::AbstractScalar>(); auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
@ -78,6 +79,24 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() {
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list); return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
} }
abstract::AbstractBasePtr MsClassObject::ToAbstract() {
auto abs_scalar =
std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<MsClassType>());
AbstractBasePtrList args_spec_list = {abs_scalar};
abstract::PrimitiveAbstractClosurePtr func_ptr = nullptr;
bool is_class_type = parse::data_converter::IsClassType(obj());
if (is_class_type) {
// Class type as func, such as Net(x, y)
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
} else {
// Class instance as func, such as net(x, y)
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCallInstance);
}
auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
ret_val->set_value_desc(ToString());
return ret_val;
}
static inline bool IsSupportedCreateInstanceType(const py::object &obj) { static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj); auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
@ -520,14 +539,15 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsCl
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << "."; MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info())); TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
constexpr size_t prefix_index = 0;
if (attr.size() > 0 && attr[prefix_index] == '_') {
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
}
py::object cls_obj = ms_class->obj(); py::object cls_obj = ms_class->obj();
if (!py::hasattr(cls_obj, attr.c_str())) { if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << "."; MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
} }
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR;
const std::string module = "mindspore._extends.parse.parser";
py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr);
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node); AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
TraceManager::ClearParseOrResolveDebugInfo(); TraceManager::ClearParseOrResolveDebugInfo();
return res_node; return res_node;

View File

@ -116,6 +116,7 @@ class PyObjectWrapper : public Named {
// the object that needs to be resolved // the object that needs to be resolved
py::object obj_; py::object obj_;
}; };
using PyObjectWrapperPtr = std::shared_ptr<PyObjectWrapper>;
// InterpretedObject class wrappers interpreted python object. // InterpretedObject class wrappers interpreted python object.
class InterpretedObject final : public PyObjectWrapper { class InterpretedObject final : public PyObjectWrapper {
@ -137,9 +138,7 @@ class MsClassObject final : public PyObjectWrapper {
: PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {} : PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
~MsClassObject() override = default; ~MsClassObject() override = default;
MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper); MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override { abstract::AbstractBasePtr ToAbstract() override;
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<External>());
}
}; };
using MsClassObjectPtr = std::shared_ptr<MsClassObject>; using MsClassObjectPtr = std::shared_ptr<MsClassObject>;

View File

@ -1304,6 +1304,35 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
return StaticGetterInferred(converted_value, data_conf, out_conf); return StaticGetterInferred(converted_value, data_conf, out_conf);
} }
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const ValuePtr &data_value, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(item_value);
MS_EXCEPTION_IF_NULL(data_value);
// Get the name of item.
if (!item_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
}
std::string item_name = item_value->cast<StringImmPtr>()->value();
// Get ms_class object.
if (!data_value->isa<parse::MsClassObject>()) {
MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString();
}
auto ms_class = data_value->cast<parse::MsClassObjectPtr>();
MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << ".";
// Get the attr/method of ms_class object.
auto out_node = out_conf->node();
FuncGraphPtr func_graph = out_node->func_graph();
auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node);
// Replace old node with the resolved new node in order list.
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value, EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const TypePtr &data_type, const ConfigPtr &data_conf, const TypePtr &data_type, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) { const AnfNodeConfigPtr &out_conf) {
@ -1363,17 +1392,45 @@ int64_t GetResolveType(const TypePtr &data_type) {
return kResolveTypeFunction; return kResolveTypeFunction;
} }
ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
if (!abs->isa<abstract::PartialAbstractClosure>()) {
return nullptr;
}
auto partial_abs = abs->cast<abstract::PartialAbstractClosurePtr>();
auto fn = partial_abs->fn();
if (!fn->isa<abstract::PrimitiveAbstractClosure>()) {
return nullptr;
}
// Check if type is kObjectTypeClass.
auto args = partial_abs->args();
if (args.size() > 0) {
constexpr size_t first_input_index = 0;
auto first_arg = args[first_input_index];
MS_EXCEPTION_IF_NULL(first_arg);
auto type = first_arg->BuildType();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() == kObjectTypeClass) {
return first_arg->BuildValue();
}
}
return nullptr;
}
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function // Inputs: namespace and its static function; or class and its member function
CheckArgsSize("StaticGetter", args_spec_list, 2); CheckArgsSize("StaticGetter", args_spec_list, 2);
MS_EXCEPTION_IF_NULL(args_spec_list[0]); constexpr size_t data_index = 0;
MS_EXCEPTION_IF_NULL(args_spec_list[1]); constexpr size_t item_index = 1;
MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString(); auto data_args = args_spec_list[data_index];
MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString(); auto item_args = args_spec_list[item_index];
TypePtr data_type = args_spec_list[0]->BuildType(); MS_EXCEPTION_IF_NULL(data_args);
ValuePtr item_value = args_spec_list[1]->BuildValue(); MS_EXCEPTION_IF_NULL(item_args);
MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString();
TypePtr data_type = data_args->BuildType();
ValuePtr item_value = item_args->BuildValue();
ScopePtr scope = kDefaultScope; ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) { if (out_conf != nullptr) {
scope = out_conf->node()->scope(); scope = out_conf->node()->scope();
@ -1384,6 +1441,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
} }
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
}
int64_t resolve_type = GetResolveType(data_type); int64_t resolve_type = GetResolveType(data_type);
if (resolve_type == kResolveTypeUserDefineClass) { if (resolve_type == kResolveTypeUserDefineClass) {
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
@ -1581,46 +1642,47 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override { const AnfNodeConfigPtr &out_conf) override {
// Check the type parameter.
if (args_spec_list.empty()) { if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
} }
constexpr size_t type_index = 0;
// Get the type parameter. auto arg_class_type = args_spec_list[type_index];
MS_EXCEPTION_IF_NULL(args_spec_list[0]); MS_EXCEPTION_IF_NULL(arg_class_type);
TypePtr type = args_spec_list[0]->GetTypeTrack(); TypePtr type = arg_class_type->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kMetaTypeTypeType) { if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) {
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " MS_LOG(EXCEPTION)
<< "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got "
<< type->ToString(); << type->ToString();
} }
ValuePtr value_track = args_spec_list[0]->GetValueTrack(); ValuePtr value_track = arg_class_type->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track); MS_EXCEPTION_IF_NULL(value_track);
parse::PyObjectWrapperPtr type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
if (type_obj == nullptr) { if (type_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
} }
if (!type_obj->isa<parse::ClassType>() && !type_obj->isa<parse::MsClassObject>()) {
if (!type_obj->isa<parse::ClassType>()) { MS_LOG(EXCEPTION)
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " << "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got "
<< type_obj->ToString() << "."; << type_obj->ToString() << ".";
} }
auto class_type = type_obj->obj(); auto class_type = type_obj->obj();
MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs). // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
py::tuple params = GetParameters(args_spec_list); py::tuple params = GetParameters(args_spec_list);
// Create class instance. // Create class instance.
auto obj = parse::data_converter::CreatePythonObject(class_type, params); auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) { if (py::isinstance<py::none>(obj)) {
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type) MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
<< "` failed, only support to create \'Cell\' or \'Primitive\' object."; << "` failed, only support to create \'Cell\', \'Primitive\' or "
<< "user-defined Class decorated with \'ms_class\'.";
} }
// Process the object. // Process the object.
MS_EXCEPTION_IF_NULL(out_conf->node());
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info())); TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
ValuePtr converted_ret = nullptr; ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(obj, &converted_ret, true); bool converted = parse::ConvertData(obj, &converted_ret, true);
@ -1628,7 +1690,6 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_LOG(EXCEPTION) << "Convert the python object failed"; MS_LOG(EXCEPTION) << "Convert the python object failed";
} }
MS_EXCEPTION_IF_NULL(converted_ret); MS_EXCEPTION_IF_NULL(converted_ret);
if (converted_ret->isa<FuncGraph>()) { if (converted_ret->isa<FuncGraph>()) {
AddToManager(engine, converted_ret->cast<FuncGraphPtr>()); AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
} }
@ -1664,6 +1725,63 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
} }
}; };
class CallInstanceEvaluator : public TransitionPrimEvaluator {
public:
CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {}
~CallInstanceEvaluator() override = default;
MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list should not be empty.";
}
constexpr size_t cls_index = 0;
auto arg_cls = args_spec_list[cls_index];
MS_EXCEPTION_IF_NULL(arg_cls);
TypePtr type = arg_cls->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kObjectTypeClass) {
MS_LOG(EXCEPTION) << "CallInstanceEvaluator require first parameter should be an object of TypeClass, but got "
<< type->ToString();
}
ValuePtr value_track = arg_cls->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
parse::MsClassObjectPtr ms_class = dyn_cast<parse::MsClassObject>(value_track);
if (ms_class == nullptr) {
MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject.";
}
// Call class instance, net(x, y) -> net.__call__(x, y)
py::object cls_obj = ms_class->obj();
const std::string call_func = "__call__";
if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
MS_LOG(EXCEPTION) << ms_class->name() << " has no " << call_func << " function, please check the code.";
}
py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
if (call_func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Parse python object " << call_func << " failed.";
}
FuncGraphManagerPtr manager = engine->func_graph_manager();
manager->AddFuncGraph(call_func_graph);
// Replace net with net.__call__
AnfNodePtr old_node = out_conf->node();
MS_EXCEPTION_IF_NULL(old_node);
CNodePtr old_cnode = dyn_cast<CNode>(old_node);
MS_EXCEPTION_IF_NULL(old_cnode);
std::vector<AnfNodePtr> inputs = {NewValueNode(call_func_graph)};
for (size_t i = 1; i < old_cnode->size(); i++) {
(void)inputs.emplace_back(old_cnode->input(i));
}
FuncGraphPtr func_graph = out_conf->func_graph();
auto new_cnode = func_graph->NewCNode(inputs);
// Continue to eval new_cnode.
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
return engine->ForwardConfig(out_conf, fn_conf);
}
};
class PyInterpretEvaluator : public TransitionPrimEvaluator { class PyInterpretEvaluator : public TransitionPrimEvaluator {
public: public:
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {} PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
@ -2085,6 +2203,7 @@ void InitPrimEvaluatorConstructors() {
constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>(); constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>(); constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>(); constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
constructor[prim::kPrimCallInstance] = std::make_shared<CallInstanceEvaluator>();
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>(); constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>(); constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>(); constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();

View File

@ -82,7 +82,7 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
if (abstract->isa<AbstractScalar>()) { if (abstract->isa<AbstractScalar>()) {
TypePtr type = abstract->GetTypeTrack(); TypePtr type = abstract->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
if (type->isa<EnvType>()) { if (type->isa<EnvType>() || type->isa<MsClassType>()) {
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString(); MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
} }
if (type->isa<Problem>() || type->isa<External>()) { if (type->isa<Problem>() || type->isa<External>()) {

View File

@ -345,6 +345,24 @@ class MS_CORE_API Problem final : public Type {
}; };
using ProblemPtr = std::shared_ptr<Problem>; using ProblemPtr = std::shared_ptr<Problem>;
/// \brief MsClassType defines a type which is ms_class.
class MS_CORE_API MsClassType final : public Type {
public:
/// \brief The constructor of External.
///
/// \return The instance of External.
MsClassType() : Type(kObjectTypeClass) {}
/// \brief The destructor of External.
~MsClassType() override = default;
MS_DECLARE_PARENT(MsClassType, Type)
TypeId generic_type_id() const override { return kObjectTypeClass; }
TypePtr DeepCopy() const override { return std::make_shared<MsClassType>(); }
std::string DumpText() const override { return "MsClassType"; }
};
using MsClassTypePtr = std::shared_ptr<MsClassType>;
/// \brief External defines a type which is external. /// \brief External defines a type which is external.
class MS_CORE_API External final : public Type { class MS_CORE_API External final : public Type {
public: public:
@ -360,9 +378,6 @@ class MS_CORE_API External final : public Type {
TypeId generic_type_id() const override { return kMetaTypeExternal; } TypeId generic_type_id() const override { return kMetaTypeExternal; }
TypePtr DeepCopy() const override { return std::make_shared<External>(); } TypePtr DeepCopy() const override { return std::make_shared<External>(); }
std::string DumpText() const override { return "ExternalType"; } std::string DumpText() const override { return "ExternalType"; }
private:
TypePtr kind;
}; };
using ExternalPtr = std::shared_ptr<External>; using ExternalPtr = std::shared_ptr<External>;

View File

@ -939,6 +939,7 @@ GVAR_DEF(PrimitivePtr, kPrimResolve, std::make_shared<Primitive>("resolve"));
GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared<Primitive>("embed")); GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared<Primitive>("embed"));
GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared<Primitive>("RefToEmbed")); GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared<Primitive>("RefToEmbed"));
GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared<Primitive>("create_instance")); GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared<Primitive>("create_instance"));
GVAR_DEF(PrimitivePtr, kPrimCallInstance, std::make_shared<Primitive>("call_instance"));
// Other miscellaneous // Other miscellaneous
GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared<Primitive>("get_ref_origin")); GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared<Primitive>("get_ref_origin"));

View File

@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load", "Switch", "Print"}; "stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
#else #else
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem", static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem", "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1", "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", "resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "UpdateState", "Load", "Switch", "Print"}; "stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
#endif #endif
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather, static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
prim::kPrimMicroStepAllGather}; prim::kPrimMicroStepAllGather};

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,20 +18,18 @@ Interfaces for parser module in c++.
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope, from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol, get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id, create_slice_obj, get_obj_id, get_module_namespace, get_obj_type, get_object_key,
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type, get_ast_type, get_node_type, get_args, get_args_default_values, get_ast_namespace_symbol,
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol, get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script, eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, is_class_type, get_dataclass_attributes, get_dataclass_methods)
get_ms_class_attr)
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', 'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol', 'create_slice_obj', 'get_obj_id', 'get_module_namespace', 'get_obj_type', 'get_object_key',
'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type', 'get_ast_type', 'get_node_type', 'get_args', 'get_args_default_values', 'get_ast_namespace_symbol',
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name', 'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement', 'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name', 'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods']
'get_ms_class_attr']

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
# #
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -118,7 +118,11 @@ class ClassMemberNamespace(Namespace):
except ValueError: except ValueError:
raise UnboundLocalError(name) raise UnboundLocalError(name)
except KeyError: except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") # Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown.
cls = d.__class__
if hasattr(cls, '__ms_class__'):
raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.")
logger.info(f"'{cls.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name) raise AttributeError(name)
@ -142,5 +146,4 @@ class ClassAttrNamespace(Namespace):
except ValueError: except ValueError:
raise UnboundLocalError(name) raise UnboundLocalError(name)
except KeyError: except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") raise AttributeError(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}'.")
raise AttributeError(name)

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
# #
# Copyright 2020-2021 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@ import hashlib
import inspect import inspect
import types import types
import importlib import importlib
from dataclasses import is_dataclass
from textwrap import dedent from textwrap import dedent
import asttokens import asttokens
@ -324,24 +323,26 @@ def get_class_instance_type(obj):
"""Get the class instance detail type.""" """Get the class instance detail type."""
# check the obj type # check the obj type
logger.debug("Get the class type(%r)", obj) logger.debug("Get the class type(%r)", obj)
class_type = CLASS_INSTANCE_TYPE_INVALID
if _is_class_instance(obj):
if isinstance(obj, nn.Cell): if isinstance(obj, nn.Cell):
class_type = CLASS_INSTANCE_TYPE_CELL return CLASS_INSTANCE_TYPE_CELL
elif isinstance(obj, ops.Primitive): if isinstance(obj, ops.Primitive):
class_type = CLASS_INSTANCE_TYPE_PRIMITIVE return CLASS_INSTANCE_TYPE_PRIMITIVE
# Add the other type base requirement return CLASS_INSTANCE_TYPE_INVALID
return class_type
def _is_class_instance(obj): def _is_ms_class(obj):
"""Confirm the obj is class instance.""" """Check if obj is ms_class object."""
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) return hasattr(obj, '__ms_class__')
def _is_dataclass_instance(obj): def _is_dataclass_instance(obj):
"""Check whether a class is an instance of a dataclass (and not a dataclass itself)""" """Check whether a class is an instance of a dataclass (and not a dataclass itself)"""
return is_dataclass(obj) and not isinstance(obj, type) return hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type)
def _is_class_instance(obj):
"""Confirm the obj is class instance."""
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) or _is_ms_class(obj)
def _convert_tuple_to_args_kwargs(params): def _convert_tuple_to_args_kwargs(params):
@ -358,7 +359,7 @@ def _convert_tuple_to_args_kwargs(params):
def is_supported_create_instance_type(cls_type): def is_supported_create_instance_type(cls_type):
"""Check if cls_type is a supported instance type.""" """Check if cls_type is a supported instance type."""
return issubclass(cls_type, (nn.Cell, ops.Primitive)) return issubclass(cls_type, (nn.Cell, ops.Primitive)) or _is_ms_class(cls_type)
def create_instance(cls_type, params=None): def create_instance(cls_type, params=None):
@ -440,28 +441,19 @@ def get_dataclass_methods(cls):
return methods return methods
def is_class_type(cls):
"""Check if cls is a class type."""
return isinstance(cls, type)
def get_ms_class_name(cls): def get_ms_class_name(cls):
"""Get the name of the class instance decorated by ms_class.""" """Get the name of the class instance decorated by ms_class."""
# Check if cls is nn.Cell. # Check if cls is nn.Cell.
if isinstance(cls, nn.Cell): if isinstance(cls, nn.Cell):
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.") raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
if isinstance(cls, type): if isinstance(cls, type):
name = cls.__name__ return cls.__name__
else: return cls.__class__.__name__
name = cls.__class__.__name__
# Get the name of cls.
cls_name = cls.__module__ + '.' + name
return cls_name
def get_ms_class_attr(cls, name: str):
"""Get attribute or method of ms_class obj."""
# Don't take into account python magic methods and private variables.
if name.startswith('_'):
raise AttributeError(f"{name} is a private variable or magic method, which is not supported.")
if not hasattr(cls, name):
raise AttributeError(f"{cls} has no attribute: {name}.")
return getattr(cls, name)
def convert_to_ms_tensor(data): def convert_to_ms_tensor(data):

View File

@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
# #
# Copyright 2020-2021 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,10 +15,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""standard_method""" """standard_method"""
from mindspore import Tensor, Parameter, CSRTensor, COOTensor, ms_class
from dataclasses import dataclass
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
from mindspore import dtype as mstype from mindspore import dtype as mstype
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
@ -1828,16 +1825,16 @@ def float_floordiv(x, y):
############# #############
@dataclass(frozen=True) @ms_class
class SequenceIterator: class SequenceIterator:
""" """
SequenceIterator is a util dataclass for iterating sequence object. SequenceIterator is a util class for iterating sequence object.
Iterator to use for sequences like List, Array. Iterator to use for sequences like List, Array.
""" """
def __init__(self, idx, seq):
idx: int self.idx = idx
seq: list self.seq = seq
@core(ignore_values=True) @core(ignore_values=True)
def __ms_hasnext__(self): def __ms_hasnext__(self):

View File

@ -0,0 +1,451 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_class
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = Tensor(1, dtype=mstype.int32)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.number
return out
net = Net()
out = net()
assert out.asnumpy() == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_method():
"""
Feature: JIT Fallback
Description: Access the methods of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = Tensor(2, dtype=mstype.int32)
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self, x, y):
out = self.inner_net.act(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
net = Net()
out = net(x, y)
assert out.asnumpy() == 10
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_call():
"""
Feature: JIT Fallback
Description: Call the __call__ function of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, val):
self.val = val
def __call__(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self, val):
super(Net, self).__init__()
self.inner_net = InnerNet(val)
def construct(self, x, y):
out = self.inner_net(x, y)
return out
val = Tensor(2, dtype=mstype.int32)
x = Tensor(3, dtype=mstype.int32)
y = Tensor(4, dtype=mstype.int32)
net = Net(val)
out = net(x, y)
assert out.asnumpy() == 14
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_input_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = Tensor(np.array([1, 2, 3]))
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.number
return out
net = Net(InnerNet)
out = net()
expect_res = np.array([1, 2, 3])
assert np.all(out.asnumpy() == expect_res)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_input_method():
"""
Feature: JIT Fallback
Description: Access the methods of user-defined classes decorated by ms_class.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = Tensor(2, dtype=mstype.int32)
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.act(1, 2)
return out
net = Net(InnerNet)
out = net()
assert out.asnumpy() == 6
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_class_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class in graph.
Expectation: No exception.
"""
@ms_class
class Inner:
def __init__(self):
self.number = Tensor(1, dtype=mstype.int32)
@ms_class
class InnerNet:
def __init__(self):
self.inner = Inner()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.inner.number
return out
net = Net()
out = net()
assert out.asnumpy() == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_cell_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class and cell in graph.
Expectation: No exception.
"""
class Net(nn.Cell):
def __init__(self, val):
super().__init__()
self.val = val
def construct(self, x):
return x + self.val
@ms_class
class TrainNet():
class Loss(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = self.net(x)
return out * 2
def __init__(self, net):
self.net = net
loss_net = self.Loss(self.net)
self.number = loss_net(10)
global_net = Net(1)
class LearnNet(nn.Cell):
def __init__(self):
super().__init__()
self.value = TrainNet(global_net).number
def construct(self, x):
return x + self.value
leanrn_net = LearnNet()
out = leanrn_net(3)
print(out)
assert out == 25
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_type_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of class type.
Expectation: No exception.
"""
@ms_class
class InnerNet:
val = Tensor(2, dtype=mstype.int32)
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
# Support accessing attributes of class type, but do not support
# accessing methods, e.g. self.inner_net.act(1, 2)
def construct(self):
out = self.inner_net.val
return out
net = Net()
out = net()
assert out == 2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_create_instance_attr():
"""
Feature: JIT Fallback
Description: Access the attributes of the created class instance.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, val):
self.number = val + 3
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
def construct(self, x):
net = self.inner_net(x)
return net.number
net = Net()
out = net(2)
assert out == 5
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_create_instance_method():
"""
Feature: JIT Fallback
Description: Access the methods of the created class instance.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, val):
self.number = val
def act(self, x, y):
return self.number * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
def construct(self, x, y, z):
net = self.inner_net(x)
return net.act(y, z)
x = 2
y = Tensor(2, dtype=mstype.int32)
z = Tensor(3, dtype=mstype.int32)
net = Net()
out = net(x, y, z)
assert out.asnumpy() == 10
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_class_create_instance_call():
"""
Feature: JIT Fallback
Description: Call the __call__ function of the created class instance.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self, number):
self.number = number
def __call__(self, x, y):
return self.number * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet
def construct(self, x, y, z):
net = self.inner_net(x)
out = net(y, z)
return out
x = 2
y = Tensor(2, dtype=mstype.int32)
z = Tensor(3, dtype=mstype.int32)
net = Net()
out = net(x, y, z)
assert out == 10
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_raise_error_not_class_type():
"""
Feature: JIT Fallback
Description: Decorator ms_class cannot be used for non-class types.
Expectation: No exception.
"""
with pytest.raises(TypeError):
@ms_class
def func(x, y):
return x + y
func(1, 2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_raise_error_decorate_cell():
"""
Feature: JIT Fallback
Description: Decorator ms_class cannot be used for nn.Cell
Expectation: No exception.
"""
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
with pytest.raises(TypeError):
x = Tensor(1)
net = Net()
net(x)

View File

@ -1,416 +0,0 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_class, ms_function
from . import test_graph_fallback
context.set_context(mode=context.GRAPH_MODE)
def test_fallback_self_attr():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.dim = 1
def construct(self, x):
batch = x.shape[0]
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
return one * x
net = Network()
x = Tensor([1, 2], mstype.float32)
out = net(x)
expect = np.array([[1., 2.], [1., 2.]])
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
def test_fallback_self_attr_fn():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self, fn):
super(Network, self).__init__()
self.fn = fn
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(x, y):
return x + y
net = Network(fn)
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
def test_fallback_self_attr_attr():
"""
Feature: JIT Fallback
Description: Test self.attr in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.value = [2, 2, 3]
def construct(self):
x = np.array(self.value.count(2))
return Tensor(x)
net = Network()
out = net()
assert out == 2
def test_fallback_self_method():
"""
Feature: JIT Fallback
Description: Test self.method in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_fallback_self_method_tensor():
"""
Feature: JIT Fallback
Description: Test self.method in graph.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
z = self.fn(x, y)
out = Tensor(z)
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
print(out)
def test_fallback_import_modules():
"""
Feature: JIT Fallback
Description: add_func is defined in test_graph_fallback.py
Expectation: No exception.
"""
@ms_function
def use_imported_module(x, y):
out = test_graph_fallback.add_func(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
out = use_imported_module(x, y)
print(out)
def test_fallback_class_attr():
"""
Feature: JIT Fallback
Description: Test user-defined class attributes in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.number
return out
net = Net()
out = net()
assert out == 1
def test_fallback_class_method():
"""
Feature: JIT Fallback
Description: Test user-defined class methods in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = 2
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.act(1, 2)
return out
net = Net()
out = net()
assert out == 6
def test_fallback_class_input_attr():
"""
Feature: JIT Fallback
Description: Test user-defined class attributes in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = Tensor(np.array([1, 2, 3]))
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.number
return out
net = Net(InnerNet)
out = net()
expect_res = np.array([1, 2, 3])
assert np.all(out.asnumpy() == expect_res)
def test_fallback_class_input_method():
"""
Feature: JIT Fallback
Description: Test user-defined class methods in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.val = 2
def act(self, x, y):
return self.val * (x + y)
class Net(nn.Cell):
def __init__(self, net):
super(Net, self).__init__()
self.inner_net = net()
def construct(self):
out = self.inner_net.act(1, 2)
return out
net = Net(InnerNet)
out = net()
assert out == 6
def test_fallback_class_class_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class in graph.
Expectation: No exception.
"""
@ms_class
class Inner:
def __init__(self):
self.number = 1
@ms_class
class InnerNet:
def __init__(self):
self.inner = Inner()
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self):
out = self.inner_net.inner.number
return out
net = Net()
out = net()
assert out == 1
def test_fallback_class_cell_nested():
"""
Feature: JIT Fallback
Description: Test nested ms_class and cell in graph.
Expectation: No exception.
"""
class Net(nn.Cell):
def __init__(self, val):
super().__init__()
self.val = val
def construct(self, x):
return x + self.val
@ms_class
class TrainNet():
class Loss(nn.Cell):
def __init__(self, net):
super().__init__()
self.net = net
def construct(self, x):
out = self.net(x)
return out * 2
def __init__(self, net):
self.net = net
loss_net = self.Loss(self.net)
self.number = loss_net(10)
global_net = Net(1)
class LearnNet(nn.Cell):
def __init__(self):
super().__init__()
self.value = TrainNet(global_net).number
def construct(self, x):
return x + self.value
leanrn_net = LearnNet()
out = leanrn_net(3)
print(out)
assert out == 25
@pytest.mark.skip(reason='Not support in graph yet')
def test_fallback_class_isinstance():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.inner_net = InnerNet()
def construct(self, x):
if isinstance(self.inner_net, InnerNet):
return x + 10
return x
net = Net()
out = net(5)
assert out == 15
def test_fallback_raise_error_not_class_type():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
with pytest.raises(TypeError):
@ms_class
def func(x, y):
return x + y
func(1, 2)
def test_fallback_raise_error_not_class_instance():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
@ms_class
class InnerNet:
def __init__(self):
self.number = 1
class Net(nn.Cell):
def construct(self):
out = InnerNet().number
return out
with pytest.raises(ValueError):
net = Net()
net()
def test_fallback_raise_error_decorate_cell():
"""
Feature: JIT Fallback
Description: Test ms_class in graph.
Expectation: No exception.
"""
@ms_class
class Net(nn.Cell):
def construct(self, x):
return x
with pytest.raises(TypeError):
x = Tensor(1)
net = Net()
net(x)

View File

@ -0,0 +1,131 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_function
from . import test_graph_fallback
context.set_context(mode=context.GRAPH_MODE)
def test_fallback_self_attr():
"""
Feature: JIT Fallback
Description: Use self.attr in expressions supported by JIT Fallback.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.dim = 1
def construct(self, x):
batch = x.shape[0]
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
return one * x
net = Network()
x = Tensor([1, 2], mstype.float32)
out = net(x)
expect = np.array([[1., 2.], [1., 2.]])
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
def test_fallback_self_attr_fn():
"""
Feature: JIT Fallback
Description: Use self.attr of type function in expressions supported by JIT Fallback.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self, fn):
super(Network, self).__init__()
self.fn = fn
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(x, y):
return x + y
net = Network(fn)
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
def test_fallback_self_attr_attr():
"""
Feature: JIT Fallback
Description: In expressions supported by JIT Fallback, use the attribute of self.attr.
Expectation: No exception.
"""
class Network(nn.Cell):
def __init__(self):
super(Network, self).__init__()
self.value = [2, 2, 3]
def construct(self):
x = np.array(self.value.count(2))
return Tensor(x)
net = Network()
out = net()
assert out == 2
def test_fallback_self_method():
"""
Feature: JIT Fallback
Description: Use self.method in expressions supported by JIT Fallback.
Expectation: No exception.
"""
class Network(nn.Cell):
def construct(self):
x = np.array([1, 2, 3])
y = np.array([3, 4, 5])
out = Tensor(self.fn(x, y))
return out
def fn(self, x, y):
return x + y
net = Network()
out = net()
expect = np.array([4, 6, 8])
assert np.all(out.asnumpy() == expect)
def test_fallback_import_modules():
"""
Feature: JIT Fallback
Description: Check whether the call to the third-party library is correct. It has nothing to do with class.
Expectation: No exception.
"""
@ms_function
def use_imported_module(x, y):
out = test_graph_fallback.add_func(x, y)
return out
x = Tensor(2, dtype=mstype.int32)
y = Tensor(3, dtype=mstype.int32)
out = use_imported_module(x, y)
print(out)