diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 6da9e08559c..98d3ee0002b 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -658,13 +658,19 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { return class_type; } -// Check the object is Cell Instance. +// Check if the object is Cell instance. bool IsCellInstance(const py::object &obj) { auto class_type = GetClassInstanceType(obj); bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL); return is_cell; } +// Check if the object is class type. +bool IsClassType(const py::object &obj) { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast(); +} + // Create the python class instance. py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h index 9c7828d1fb3..5bcf13e7125 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2019-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj); ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); bool IsCellInstance(const py::object &obj); +bool IsClassType(const py::object &obj); py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs); py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs); void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index 5914b46d202..edff068fe4b 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -67,8 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance"; const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type"; const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods"; +const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type"; const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name"; -const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr"; const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 8469ddf4bd4..1141a2a695a 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -67,6 +67,7 @@ struct AnfDumpHandlerRegister { } } callback_register; } // namespace + abstract::AbstractBasePtr ClassObject::ToAbstract() { ClassPtr cls_ptr = ParseDataClass(obj()); auto abs_scalar = std::make_shared(); @@ -78,6 +79,24 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() { return std::make_shared(func_ptr, args_spec_list); } +abstract::AbstractBasePtr MsClassObject::ToAbstract() { + auto abs_scalar = + std::make_shared(shared_from_base(), std::make_shared()); + AbstractBasePtrList args_spec_list = {abs_scalar}; + abstract::PrimitiveAbstractClosurePtr func_ptr = nullptr; + bool is_class_type = parse::data_converter::IsClassType(obj()); + if (is_class_type) { + // Class type as func, such as Net(x, y) + func_ptr = std::make_shared(prim::kPrimCreateInstance); + } else { + // Class instance as func, such as net(x, y) + func_ptr = std::make_shared(prim::kPrimCallInstance); + } + auto ret_val = std::make_shared(func_ptr, args_spec_list); + ret_val->set_value_desc(ToString()); + return ret_val; +} + static inline bool IsSupportedCreateInstanceType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj); @@ -520,14 +539,15 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsCl MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << "."; TraceGuard trace_guard(std::make_shared(node->debug_info())); + constexpr size_t prefix_index = 0; + if (attr.size() > 0 && attr[prefix_index] == '_') { + MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported."; + } py::object cls_obj = ms_class->obj(); - if (!py::hasattr(cls_obj, attr.c_str())) { + if (!py::hasattr(cls_obj, common::SafeCStr(attr))) { MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << "."; } - - const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR; - const std::string module = "mindspore._extends.parse.parser"; - py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr); + py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr)); AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node); TraceManager::ClearParseOrResolveDebugInfo(); return res_node; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h index 122c6724cc3..43fa2a4750f 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -116,6 +116,7 @@ class PyObjectWrapper : public Named { // the object that needs to be resolved py::object obj_; }; +using PyObjectWrapperPtr = std::shared_ptr; // InterpretedObject class wrappers interpreted python object. class InterpretedObject final : public PyObjectWrapper { @@ -137,9 +138,7 @@ class MsClassObject final : public PyObjectWrapper { : PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {} ~MsClassObject() override = default; MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper); - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } + abstract::AbstractBasePtr ToAbstract() override; }; using MsClassObjectPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 7519dc7cf09..8b54ff35529 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -1304,6 +1304,35 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng return StaticGetterInferred(converted_value, data_conf, out_conf); } +EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value, + const ValuePtr &data_value, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &out_conf) { + MS_EXCEPTION_IF_NULL(item_value); + MS_EXCEPTION_IF_NULL(data_value); + // Get the name of item. + if (!item_value->isa()) { + MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString(); + } + std::string item_name = item_value->cast()->value(); + // Get ms_class object. + if (!data_value->isa()) { + MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString(); + } + auto ms_class = data_value->cast(); + MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << "."; + + // Get the attr/method of ms_class object. + auto out_node = out_conf->node(); + FuncGraphPtr func_graph = out_node->func_graph(); + auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node); + // Replace old node with the resolved new node in order list. + func_graph->ReplaceInOrder(out_node, new_node); + AnalysisEnginePtr eng = out_conf->engine(); + MS_EXCEPTION_IF_NULL(eng); + AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph()); + return eng->ForwardConfig(out_conf, fn_conf); +} + EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value, const TypePtr &data_type, const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { @@ -1363,17 +1392,45 @@ int64_t GetResolveType(const TypePtr &data_type) { return kResolveTypeFunction; } +ValuePtr GetMsClassObject(const AbstractBasePtr &abs) { + if (!abs->isa()) { + return nullptr; + } + auto partial_abs = abs->cast(); + auto fn = partial_abs->fn(); + if (!fn->isa()) { + return nullptr; + } + // Check if type is kObjectTypeClass. + auto args = partial_abs->args(); + if (args.size() > 0) { + constexpr size_t first_input_index = 0; + auto first_arg = args[first_input_index]; + MS_EXCEPTION_IF_NULL(first_arg); + auto type = first_arg->BuildType(); + MS_EXCEPTION_IF_NULL(type); + if (type->type_id() == kObjectTypeClass) { + return first_arg->BuildValue(); + } + } + return nullptr; +} + EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { // Inputs: namespace and its static function; or class and its member function CheckArgsSize("StaticGetter", args_spec_list, 2); - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString(); - MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString(); - TypePtr data_type = args_spec_list[0]->BuildType(); - ValuePtr item_value = args_spec_list[1]->BuildValue(); + constexpr size_t data_index = 0; + constexpr size_t item_index = 1; + auto data_args = args_spec_list[data_index]; + auto item_args = args_spec_list[item_index]; + MS_EXCEPTION_IF_NULL(data_args); + MS_EXCEPTION_IF_NULL(item_args); + MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString(); + TypePtr data_type = data_args->BuildType(); + ValuePtr item_value = item_args->BuildValue(); + ScopePtr scope = kDefaultScope; if (out_conf != nullptr) { scope = out_conf->node()->scope(); @@ -1384,6 +1441,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); } + auto class_value = GetMsClassObject(data_args); + if (class_value != nullptr) { + return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf); + } int64_t resolve_type = GetResolveType(data_type); if (resolve_type == kResolveTypeUserDefineClass) { return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); @@ -1581,46 +1642,47 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { + // Check the type parameter. if (args_spec_list.empty()) { MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; } - - // Get the type parameter. - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - TypePtr type = args_spec_list[0]->GetTypeTrack(); + constexpr size_t type_index = 0; + auto arg_class_type = args_spec_list[type_index]; + MS_EXCEPTION_IF_NULL(arg_class_type); + TypePtr type = arg_class_type->GetTypeTrack(); MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kMetaTypeTypeType) { - MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " - << type->ToString(); + if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) { + MS_LOG(EXCEPTION) + << "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got " + << type->ToString(); } - ValuePtr value_track = args_spec_list[0]->GetValueTrack(); + ValuePtr value_track = arg_class_type->GetValueTrack(); MS_EXCEPTION_IF_NULL(value_track); - - std::shared_ptr type_obj = dyn_cast(value_track); + parse::PyObjectWrapperPtr type_obj = dyn_cast(value_track); if (type_obj == nullptr) { MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; } - - if (!type_obj->isa()) { - MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " - << type_obj->ToString() << "."; + if (!type_obj->isa() && !type_obj->isa()) { + MS_LOG(EXCEPTION) + << "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got " + << type_obj->ToString() << "."; } - auto class_type = type_obj->obj(); - MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; + MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << "."; // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs). py::tuple params = GetParameters(args_spec_list); - // Create class instance. auto obj = parse::data_converter::CreatePythonObject(class_type, params); if (py::isinstance(obj)) { MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type) - << "` failed, only support to create \'Cell\' or \'Primitive\' object."; + << "` failed, only support to create \'Cell\', \'Primitive\' or " + << "user-defined Class decorated with \'ms_class\'."; } // Process the object. + MS_EXCEPTION_IF_NULL(out_conf->node()); TraceGuard guard(std::make_shared(out_conf->node()->debug_info())); ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret, true); @@ -1628,7 +1690,6 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { MS_LOG(EXCEPTION) << "Convert the python object failed"; } MS_EXCEPTION_IF_NULL(converted_ret); - if (converted_ret->isa()) { AddToManager(engine, converted_ret->cast()); } @@ -1664,6 +1725,63 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { } }; +class CallInstanceEvaluator : public TransitionPrimEvaluator { + public: + CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {} + ~CallInstanceEvaluator() override = default; + MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, + const AnfNodeConfigPtr &out_conf) override { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << "args_spec_list should not be empty."; + } + constexpr size_t cls_index = 0; + auto arg_cls = args_spec_list[cls_index]; + MS_EXCEPTION_IF_NULL(arg_cls); + TypePtr type = arg_cls->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(type); + if (type->type_id() != kObjectTypeClass) { + MS_LOG(EXCEPTION) << "CallInstanceEvaluator require first parameter should be an object of TypeClass, but got " + << type->ToString(); + } + ValuePtr value_track = arg_cls->GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + parse::MsClassObjectPtr ms_class = dyn_cast(value_track); + if (ms_class == nullptr) { + MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject."; + } + + // Call class instance, net(x, y) -> net.__call__(x, y) + py::object cls_obj = ms_class->obj(); + const std::string call_func = "__call__"; + if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) { + MS_LOG(EXCEPTION) << ms_class->name() << " has no " << call_func << " function, please check the code."; + } + py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func)); + FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj); + if (call_func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Parse python object " << call_func << " failed."; + } + FuncGraphManagerPtr manager = engine->func_graph_manager(); + manager->AddFuncGraph(call_func_graph); + + // Replace net with net.__call__ + AnfNodePtr old_node = out_conf->node(); + MS_EXCEPTION_IF_NULL(old_node); + CNodePtr old_cnode = dyn_cast(old_node); + MS_EXCEPTION_IF_NULL(old_cnode); + std::vector inputs = {NewValueNode(call_func_graph)}; + for (size_t i = 1; i < old_cnode->size(); i++) { + (void)inputs.emplace_back(old_cnode->input(i)); + } + FuncGraphPtr func_graph = out_conf->func_graph(); + auto new_cnode = func_graph->NewCNode(inputs); + // Continue to eval new_cnode. + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph()); + return engine->ForwardConfig(out_conf, fn_conf); + } +}; + class PyInterpretEvaluator : public TransitionPrimEvaluator { public: PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {} @@ -2085,6 +2203,7 @@ void InitPrimEvaluatorConstructors() { constructor[prim::kPrimGetAttr] = std::make_shared(); constructor[prim::kPrimResolve] = std::make_shared(); constructor[prim::kPrimCreateInstance] = std::make_shared(); + constructor[prim::kPrimCallInstance] = std::make_shared(); constructor[prim::kPrimPartial] = std::make_shared(); constructor[prim::kPrimPyInterpret] = std::make_shared(); constructor[prim::kPrimMakeTuple] = std::make_shared(); diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 86268a09a12..0ee01dcb5d4 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -82,7 +82,7 @@ bool CheckAbstractScalar(const AnfNodePtr &node) { if (abstract->isa()) { TypePtr type = abstract->GetTypeTrack(); MS_EXCEPTION_IF_NULL(type); - if (type->isa()) { + if (type->isa() || type->isa()) { MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString(); } if (type->isa() || type->isa()) { diff --git a/mindspore/core/ir/dtype.h b/mindspore/core/ir/dtype.h index 31cf697879a..73806568f89 100644 --- a/mindspore/core/ir/dtype.h +++ b/mindspore/core/ir/dtype.h @@ -345,6 +345,24 @@ class MS_CORE_API Problem final : public Type { }; using ProblemPtr = std::shared_ptr; +/// \brief MsClassType defines a type which is ms_class. +class MS_CORE_API MsClassType final : public Type { + public: + /// \brief The constructor of External. + /// + /// \return The instance of External. + MsClassType() : Type(kObjectTypeClass) {} + + /// \brief The destructor of External. + ~MsClassType() override = default; + MS_DECLARE_PARENT(MsClassType, Type) + + TypeId generic_type_id() const override { return kObjectTypeClass; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string DumpText() const override { return "MsClassType"; } +}; +using MsClassTypePtr = std::shared_ptr; + /// \brief External defines a type which is external. class MS_CORE_API External final : public Type { public: @@ -360,9 +378,6 @@ class MS_CORE_API External final : public Type { TypeId generic_type_id() const override { return kMetaTypeExternal; } TypePtr DeepCopy() const override { return std::make_shared(); } std::string DumpText() const override { return "ExternalType"; } - - private: - TypePtr kind; }; using ExternalPtr = std::shared_ptr; diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 0deb5f1f7c8..bc4094c4d4f 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -939,6 +939,7 @@ GVAR_DEF(PrimitivePtr, kPrimResolve, std::make_shared("resolve")); GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared("embed")); GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared("RefToEmbed")); GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared("create_instance")); +GVAR_DEF(PrimitivePtr, kPrimCallInstance, std::make_shared("call_instance")); // Other miscellaneous GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared("get_ref_origin")); diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc index 132d81ad201..89b071f4200 100644 --- a/mindspore/core/utils/parallel_node_check.cc +++ b/mindspore/core/utils/parallel_node_check.cc @@ -31,7 +31,7 @@ static const std::set PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", - "stop_gradient", "UpdateState", "Load", "Switch", "Print"}; + "stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"}; #else static const std::set PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem", "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem", @@ -40,7 +40,7 @@ static const std::set PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1", "resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", - "stop_gradient", "UpdateState", "Load", "Switch", "Print"}; + "stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"}; #endif static const std::set ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather, prim::kPrimMicroStepAllGather}; diff --git a/mindspore/python/mindspore/_extends/parse/__init__.py b/mindspore/python/mindspore/_extends/parse/__init__.py index 99b0a40e7fe..9d41cb8f182 100644 --- a/mindspore/python/mindspore/_extends/parse/__init__.py +++ b/mindspore/python/mindspore/_extends/parse/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,20 +18,18 @@ Interfaces for parser module in c++. from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope, get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol, - create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id, - get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type, - get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol, - get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script, - expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, - get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, - get_ms_class_attr) + create_slice_obj, get_obj_id, get_module_namespace, get_obj_type, get_object_key, + get_ast_type, get_node_type, get_args, get_args_default_values, get_ast_namespace_symbol, + get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, + eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol, + convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, + is_class_type, get_dataclass_attributes, get_dataclass_methods) -__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', - 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', - 'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol', - 'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type', - 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', - 'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name', - 'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement', - 'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name', - 'get_ms_class_attr'] +__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope', + 'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol', + 'create_slice_obj', 'get_obj_id', 'get_module_namespace', 'get_obj_type', 'get_object_key', + 'get_ast_type', 'get_node_type', 'get_args', 'get_args_default_values', 'get_ast_namespace_symbol', + 'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name', + 'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol', + 'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name', + 'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods'] diff --git a/mindspore/python/mindspore/_extends/parse/namespace.py b/mindspore/python/mindspore/_extends/parse/namespace.py index df904e0a9d7..90a41c3a023 100644 --- a/mindspore/python/mindspore/_extends/parse/namespace.py +++ b/mindspore/python/mindspore/_extends/parse/namespace.py @@ -1,6 +1,6 @@ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -118,7 +118,11 @@ class ClassMemberNamespace(Namespace): except ValueError: raise UnboundLocalError(name) except KeyError: - logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") + # Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown. + cls = d.__class__ + if hasattr(cls, '__ms_class__'): + raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.") + logger.info(f"'{cls.__name__ }' object has no attribute or method: '{name}', so will return None.") raise AttributeError(name) @@ -142,5 +146,4 @@ class ClassAttrNamespace(Namespace): except ValueError: raise UnboundLocalError(name) except KeyError: - logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") - raise AttributeError(name) + raise AttributeError(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}'.") diff --git a/mindspore/python/mindspore/_extends/parse/parser.py b/mindspore/python/mindspore/_extends/parse/parser.py index 7e0d9a3a111..db442ec9358 100644 --- a/mindspore/python/mindspore/_extends/parse/parser.py +++ b/mindspore/python/mindspore/_extends/parse/parser.py @@ -1,6 +1,6 @@ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # -# Copyright 2020-2021 Huawei Technologies Co., Ltd +# Copyright 2020-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,6 @@ import hashlib import inspect import types import importlib -from dataclasses import is_dataclass from textwrap import dedent import asttokens @@ -324,24 +323,26 @@ def get_class_instance_type(obj): """Get the class instance detail type.""" # check the obj type logger.debug("Get the class type(%r)", obj) - class_type = CLASS_INSTANCE_TYPE_INVALID - if _is_class_instance(obj): - if isinstance(obj, nn.Cell): - class_type = CLASS_INSTANCE_TYPE_CELL - elif isinstance(obj, ops.Primitive): - class_type = CLASS_INSTANCE_TYPE_PRIMITIVE - # Add the other type base requirement - return class_type + if isinstance(obj, nn.Cell): + return CLASS_INSTANCE_TYPE_CELL + if isinstance(obj, ops.Primitive): + return CLASS_INSTANCE_TYPE_PRIMITIVE + return CLASS_INSTANCE_TYPE_INVALID -def _is_class_instance(obj): - """Confirm the obj is class instance.""" - return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) +def _is_ms_class(obj): + """Check if obj is ms_class object.""" + return hasattr(obj, '__ms_class__') def _is_dataclass_instance(obj): """Check whether a class is an instance of a dataclass (and not a dataclass itself)""" - return is_dataclass(obj) and not isinstance(obj, type) + return hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type) + + +def _is_class_instance(obj): + """Confirm the obj is class instance.""" + return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) or _is_ms_class(obj) def _convert_tuple_to_args_kwargs(params): @@ -358,7 +359,7 @@ def _convert_tuple_to_args_kwargs(params): def is_supported_create_instance_type(cls_type): """Check if cls_type is a supported instance type.""" - return issubclass(cls_type, (nn.Cell, ops.Primitive)) + return issubclass(cls_type, (nn.Cell, ops.Primitive)) or _is_ms_class(cls_type) def create_instance(cls_type, params=None): @@ -440,28 +441,19 @@ def get_dataclass_methods(cls): return methods +def is_class_type(cls): + """Check if cls is a class type.""" + return isinstance(cls, type) + + def get_ms_class_name(cls): """Get the name of the class instance decorated by ms_class.""" # Check if cls is nn.Cell. if isinstance(cls, nn.Cell): raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.") if isinstance(cls, type): - name = cls.__name__ - else: - name = cls.__class__.__name__ - # Get the name of cls. - cls_name = cls.__module__ + '.' + name - return cls_name - - -def get_ms_class_attr(cls, name: str): - """Get attribute or method of ms_class obj.""" - # Don't take into account python magic methods and private variables. - if name.startswith('_'): - raise AttributeError(f"{name} is a private variable or magic method, which is not supported.") - if not hasattr(cls, name): - raise AttributeError(f"{cls} has no attribute: {name}.") - return getattr(cls, name) + return cls.__name__ + return cls.__class__.__name__ def convert_to_ms_tensor(data): diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index d465bf80052..219c2af2e20 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -1,6 +1,6 @@ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # -# Copyright 2020-2021 Huawei Technologies Co., Ltd +# Copyright 2020-2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,7 @@ # limitations under the License. # ============================================================================ """standard_method""" - -from dataclasses import dataclass - -from mindspore import Tensor, Parameter, CSRTensor, COOTensor +from mindspore import Tensor, Parameter, CSRTensor, COOTensor, ms_class from mindspore import dtype as mstype from ..._checkparam import Validator as validator @@ -1828,16 +1825,16 @@ def float_floordiv(x, y): ############# -@dataclass(frozen=True) +@ms_class class SequenceIterator: """ - SequenceIterator is a util dataclass for iterating sequence object. + SequenceIterator is a util class for iterating sequence object. Iterator to use for sequences like List, Array. """ - - idx: int - seq: list + def __init__(self, idx, seq): + self.idx = idx + self.seq = seq @core(ignore_values=True) def __ms_hasnext__(self): diff --git a/tests/st/fallback/test_graph_fallback_class.py b/tests/st/fallback/test_graph_fallback_class.py new file mode 100644 index 00000000000..4bbfbfa98b7 --- /dev/null +++ b/tests/st/fallback/test_graph_fallback_class.py @@ -0,0 +1,451 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test graph fallback """ +import pytest +import numpy as np + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor, context, ms_class + +context.set_context(mode=context.GRAPH_MODE) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_attr(): + """ + Feature: JIT Fallback + Description: Access the attributes of user-defined classes decorated by ms_class. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.number = Tensor(1, dtype=mstype.int32) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self): + out = self.inner_net.number + return out + + net = Net() + out = net() + assert out.asnumpy() == 1 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_method(): + """ + Feature: JIT Fallback + Description: Access the methods of user-defined classes decorated by ms_class. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.val = Tensor(2, dtype=mstype.int32) + + def act(self, x, y): + return self.val * (x + y) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self, x, y): + out = self.inner_net.act(x, y) + return out + + x = Tensor(2, dtype=mstype.int32) + y = Tensor(3, dtype=mstype.int32) + net = Net() + out = net(x, y) + assert out.asnumpy() == 10 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_call(): + """ + Feature: JIT Fallback + Description: Call the __call__ function of user-defined classes decorated by ms_class. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self, val): + self.val = val + + def __call__(self, x, y): + return self.val * (x + y) + + class Net(nn.Cell): + def __init__(self, val): + super(Net, self).__init__() + self.inner_net = InnerNet(val) + + def construct(self, x, y): + out = self.inner_net(x, y) + return out + + val = Tensor(2, dtype=mstype.int32) + x = Tensor(3, dtype=mstype.int32) + y = Tensor(4, dtype=mstype.int32) + net = Net(val) + out = net(x, y) + assert out.asnumpy() == 14 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_input_attr(): + """ + Feature: JIT Fallback + Description: Access the attributes of user-defined classes decorated by ms_class. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.number = Tensor(np.array([1, 2, 3])) + + class Net(nn.Cell): + def __init__(self, net): + super(Net, self).__init__() + self.inner_net = net() + + def construct(self): + out = self.inner_net.number + return out + + net = Net(InnerNet) + out = net() + expect_res = np.array([1, 2, 3]) + assert np.all(out.asnumpy() == expect_res) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_input_method(): + """ + Feature: JIT Fallback + Description: Access the methods of user-defined classes decorated by ms_class. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.val = Tensor(2, dtype=mstype.int32) + + def act(self, x, y): + return self.val * (x + y) + + class Net(nn.Cell): + def __init__(self, net): + super(Net, self).__init__() + self.inner_net = net() + + def construct(self): + out = self.inner_net.act(1, 2) + return out + + net = Net(InnerNet) + out = net() + assert out.asnumpy() == 6 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_class_nested(): + """ + Feature: JIT Fallback + Description: Test nested ms_class in graph. + Expectation: No exception. + """ + @ms_class + class Inner: + def __init__(self): + self.number = Tensor(1, dtype=mstype.int32) + + @ms_class + class InnerNet: + def __init__(self): + self.inner = Inner() + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self): + out = self.inner_net.inner.number + return out + + net = Net() + out = net() + assert out.asnumpy() == 1 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_cell_nested(): + """ + Feature: JIT Fallback + Description: Test nested ms_class and cell in graph. + Expectation: No exception. + """ + class Net(nn.Cell): + def __init__(self, val): + super().__init__() + self.val = val + + def construct(self, x): + return x + self.val + + @ms_class + class TrainNet(): + class Loss(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + + def construct(self, x): + out = self.net(x) + return out * 2 + + def __init__(self, net): + self.net = net + loss_net = self.Loss(self.net) + self.number = loss_net(10) + + global_net = Net(1) + class LearnNet(nn.Cell): + def __init__(self): + super().__init__() + self.value = TrainNet(global_net).number + + def construct(self, x): + return x + self.value + + leanrn_net = LearnNet() + out = leanrn_net(3) + print(out) + assert out == 25 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_type_attr(): + """ + Feature: JIT Fallback + Description: Access the attributes of class type. + Expectation: No exception. + """ + @ms_class + class InnerNet: + val = Tensor(2, dtype=mstype.int32) + + def act(self, x, y): + return self.val * (x + y) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet + + # Support accessing attributes of class type, but do not support + # accessing methods, e.g. self.inner_net.act(1, 2) + def construct(self): + out = self.inner_net.val + return out + + net = Net() + out = net() + assert out == 2 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_create_instance_attr(): + """ + Feature: JIT Fallback + Description: Access the attributes of the created class instance. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self, val): + self.number = val + 3 + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet + + def construct(self, x): + net = self.inner_net(x) + return net.number + + net = Net() + out = net(2) + assert out == 5 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_create_instance_method(): + """ + Feature: JIT Fallback + Description: Access the methods of the created class instance. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self, val): + self.number = val + + def act(self, x, y): + return self.number * (x + y) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet + + def construct(self, x, y, z): + net = self.inner_net(x) + return net.act(y, z) + + x = 2 + y = Tensor(2, dtype=mstype.int32) + z = Tensor(3, dtype=mstype.int32) + net = Net() + out = net(x, y, z) + assert out.asnumpy() == 10 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_class_create_instance_call(): + """ + Feature: JIT Fallback + Description: Call the __call__ function of the created class instance. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self, number): + self.number = number + + def __call__(self, x, y): + return self.number * (x + y) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet + + def construct(self, x, y, z): + net = self.inner_net(x) + out = net(y, z) + return out + + x = 2 + y = Tensor(2, dtype=mstype.int32) + z = Tensor(3, dtype=mstype.int32) + net = Net() + out = net(x, y, z) + assert out == 10 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_raise_error_not_class_type(): + """ + Feature: JIT Fallback + Description: Decorator ms_class cannot be used for non-class types. + Expectation: No exception. + """ + with pytest.raises(TypeError): + @ms_class + def func(x, y): + return x + y + + func(1, 2) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fallback_raise_error_decorate_cell(): + """ + Feature: JIT Fallback + Description: Decorator ms_class cannot be used for nn.Cell + Expectation: No exception. + """ + @ms_class + class Net(nn.Cell): + def construct(self, x): + return x + + with pytest.raises(TypeError): + x = Tensor(1) + net = Net() + net(x) diff --git a/tests/ut/python/fallback/test_graph_fallback_class.py b/tests/ut/python/fallback/test_graph_fallback_class.py deleted file mode 100644 index 5105eab7a60..00000000000 --- a/tests/ut/python/fallback/test_graph_fallback_class.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test graph fallback """ -import pytest -import numpy as np - -import mindspore.nn as nn -import mindspore.common.dtype as mstype -from mindspore import Tensor, context, ms_class, ms_function -from . import test_graph_fallback - -context.set_context(mode=context.GRAPH_MODE) - - -def test_fallback_self_attr(): - """ - Feature: JIT Fallback - Description: Test self.attr in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def __init__(self): - super(Network, self).__init__() - self.dim = 1 - - def construct(self, x): - batch = x.shape[0] - one = Tensor(np.ones([batch, self.dim]), mstype.float32) - return one * x - - net = Network() - x = Tensor([1, 2], mstype.float32) - out = net(x) - expect = np.array([[1., 2.], [1., 2.]]) - assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2) - - -def test_fallback_self_attr_fn(): - """ - Feature: JIT Fallback - Description: Test self.attr in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def __init__(self, fn): - super(Network, self).__init__() - self.fn = fn - - def construct(self): - x = np.array([1, 2, 3]) - y = np.array([3, 4, 5]) - out = Tensor(self.fn(x, y)) - return out - - def fn(x, y): - return x + y - - net = Network(fn) - out = net() - expect = np.array([4, 6, 8]) - assert np.all(out.asnumpy() == expect) - - -def test_fallback_self_attr_attr(): - """ - Feature: JIT Fallback - Description: Test self.attr in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def __init__(self): - super(Network, self).__init__() - self.value = [2, 2, 3] - - def construct(self): - x = np.array(self.value.count(2)) - return Tensor(x) - - net = Network() - out = net() - assert out == 2 - - -def test_fallback_self_method(): - """ - Feature: JIT Fallback - Description: Test self.method in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def construct(self): - x = np.array([1, 2, 3]) - y = np.array([3, 4, 5]) - out = Tensor(self.fn(x, y)) - return out - - def fn(self, x, y): - return x + y - - net = Network() - out = net() - expect = np.array([4, 6, 8]) - assert np.all(out.asnumpy() == expect) - - -@pytest.mark.skip(reason='Not support in graph jit fallback feature yet') -def test_fallback_self_method_tensor(): - """ - Feature: JIT Fallback - Description: Test self.method in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def construct(self): - x = np.array([1, 2, 3]) - y = np.array([3, 4, 5]) - z = self.fn(x, y) - out = Tensor(z) - return out - - def fn(self, x, y): - return x + y - - net = Network() - out = net() - print(out) - - -def test_fallback_import_modules(): - """ - Feature: JIT Fallback - Description: add_func is defined in test_graph_fallback.py - Expectation: No exception. - """ - @ms_function - def use_imported_module(x, y): - out = test_graph_fallback.add_func(x, y) - return out - - x = Tensor(2, dtype=mstype.int32) - y = Tensor(3, dtype=mstype.int32) - out = use_imported_module(x, y) - print(out) - - -def test_fallback_class_attr(): - """ - Feature: JIT Fallback - Description: Test user-defined class attributes in graph. - Expectation: No exception. - """ - @ms_class - class InnerNet: - def __init__(self): - self.number = 1 - - class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.inner_net = InnerNet() - - def construct(self): - out = self.inner_net.number - return out - - net = Net() - out = net() - assert out == 1 - - -def test_fallback_class_method(): - """ - Feature: JIT Fallback - Description: Test user-defined class methods in graph. - Expectation: No exception. - """ - @ms_class - class InnerNet: - def __init__(self): - self.val = 2 - - def act(self, x, y): - return self.val * (x + y) - - class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.inner_net = InnerNet() - - def construct(self): - out = self.inner_net.act(1, 2) - return out - - net = Net() - out = net() - assert out == 6 - - -def test_fallback_class_input_attr(): - """ - Feature: JIT Fallback - Description: Test user-defined class attributes in graph. - Expectation: No exception. - """ - @ms_class - class InnerNet: - def __init__(self): - self.number = Tensor(np.array([1, 2, 3])) - - class Net(nn.Cell): - def __init__(self, net): - super(Net, self).__init__() - self.inner_net = net() - - def construct(self): - out = self.inner_net.number - return out - - net = Net(InnerNet) - out = net() - expect_res = np.array([1, 2, 3]) - assert np.all(out.asnumpy() == expect_res) - - -def test_fallback_class_input_method(): - """ - Feature: JIT Fallback - Description: Test user-defined class methods in graph. - Expectation: No exception. - """ - @ms_class - class InnerNet: - def __init__(self): - self.val = 2 - - def act(self, x, y): - return self.val * (x + y) - - class Net(nn.Cell): - def __init__(self, net): - super(Net, self).__init__() - self.inner_net = net() - - def construct(self): - out = self.inner_net.act(1, 2) - return out - - net = Net(InnerNet) - out = net() - assert out == 6 - - -def test_fallback_class_class_nested(): - """ - Feature: JIT Fallback - Description: Test nested ms_class in graph. - Expectation: No exception. - """ - @ms_class - class Inner: - def __init__(self): - self.number = 1 - - @ms_class - class InnerNet: - def __init__(self): - self.inner = Inner() - - class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.inner_net = InnerNet() - - def construct(self): - out = self.inner_net.inner.number - return out - - net = Net() - out = net() - assert out == 1 - - -def test_fallback_class_cell_nested(): - """ - Feature: JIT Fallback - Description: Test nested ms_class and cell in graph. - Expectation: No exception. - """ - class Net(nn.Cell): - def __init__(self, val): - super().__init__() - self.val = val - - def construct(self, x): - return x + self.val - - @ms_class - class TrainNet(): - class Loss(nn.Cell): - def __init__(self, net): - super().__init__() - self.net = net - - def construct(self, x): - out = self.net(x) - return out * 2 - - def __init__(self, net): - self.net = net - loss_net = self.Loss(self.net) - self.number = loss_net(10) - - global_net = Net(1) - class LearnNet(nn.Cell): - def __init__(self): - super().__init__() - self.value = TrainNet(global_net).number - - def construct(self, x): - return x + self.value - - leanrn_net = LearnNet() - out = leanrn_net(3) - print(out) - assert out == 25 - - -@pytest.mark.skip(reason='Not support in graph yet') -def test_fallback_class_isinstance(): - """ - Feature: JIT Fallback - Description: Test ms_class in graph. - Expectation: No exception. - """ - @ms_class - class InnerNet: - def __init__(self): - self.number = 1 - - class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.inner_net = InnerNet() - - def construct(self, x): - if isinstance(self.inner_net, InnerNet): - return x + 10 - return x - - net = Net() - out = net(5) - assert out == 15 - - -def test_fallback_raise_error_not_class_type(): - """ - Feature: JIT Fallback - Description: Test ms_class in graph. - Expectation: No exception. - """ - with pytest.raises(TypeError): - @ms_class - def func(x, y): - return x + y - - func(1, 2) - - -def test_fallback_raise_error_not_class_instance(): - """ - Feature: JIT Fallback - Description: Test ms_class in graph. - Expectation: No exception. - """ - @ms_class - class InnerNet: - def __init__(self): - self.number = 1 - - class Net(nn.Cell): - def construct(self): - out = InnerNet().number - return out - - with pytest.raises(ValueError): - net = Net() - net() - - -def test_fallback_raise_error_decorate_cell(): - """ - Feature: JIT Fallback - Description: Test ms_class in graph. - Expectation: No exception. - """ - @ms_class - class Net(nn.Cell): - def construct(self, x): - return x - - with pytest.raises(TypeError): - x = Tensor(1) - net = Net() - net(x) diff --git a/tests/ut/python/fallback/test_graph_fallback_self.py b/tests/ut/python/fallback/test_graph_fallback_self.py new file mode 100644 index 00000000000..cbdac70a4ff --- /dev/null +++ b/tests/ut/python/fallback/test_graph_fallback_self.py @@ -0,0 +1,131 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test graph fallback """ +import numpy as np + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor, context, ms_function +from . import test_graph_fallback + +context.set_context(mode=context.GRAPH_MODE) + + +def test_fallback_self_attr(): + """ + Feature: JIT Fallback + Description: Use self.attr in expressions supported by JIT Fallback. + Expectation: No exception. + """ + class Network(nn.Cell): + def __init__(self): + super(Network, self).__init__() + self.dim = 1 + + def construct(self, x): + batch = x.shape[0] + one = Tensor(np.ones([batch, self.dim]), mstype.float32) + return one * x + + net = Network() + x = Tensor([1, 2], mstype.float32) + out = net(x) + expect = np.array([[1., 2.], [1., 2.]]) + assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2) + + +def test_fallback_self_attr_fn(): + """ + Feature: JIT Fallback + Description: Use self.attr of type function in expressions supported by JIT Fallback. + Expectation: No exception. + """ + class Network(nn.Cell): + def __init__(self, fn): + super(Network, self).__init__() + self.fn = fn + + def construct(self): + x = np.array([1, 2, 3]) + y = np.array([3, 4, 5]) + out = Tensor(self.fn(x, y)) + return out + + def fn(x, y): + return x + y + + net = Network(fn) + out = net() + expect = np.array([4, 6, 8]) + assert np.all(out.asnumpy() == expect) + + +def test_fallback_self_attr_attr(): + """ + Feature: JIT Fallback + Description: In expressions supported by JIT Fallback, use the attribute of self.attr. + Expectation: No exception. + """ + class Network(nn.Cell): + def __init__(self): + super(Network, self).__init__() + self.value = [2, 2, 3] + + def construct(self): + x = np.array(self.value.count(2)) + return Tensor(x) + + net = Network() + out = net() + assert out == 2 + + +def test_fallback_self_method(): + """ + Feature: JIT Fallback + Description: Use self.method in expressions supported by JIT Fallback. + Expectation: No exception. + """ + class Network(nn.Cell): + def construct(self): + x = np.array([1, 2, 3]) + y = np.array([3, 4, 5]) + out = Tensor(self.fn(x, y)) + return out + + def fn(self, x, y): + return x + y + + net = Network() + out = net() + expect = np.array([4, 6, 8]) + assert np.all(out.asnumpy() == expect) + + +def test_fallback_import_modules(): + """ + Feature: JIT Fallback + Description: Check whether the call to the third-party library is correct. It has nothing to do with class. + Expectation: No exception. + """ + @ms_function + def use_imported_module(x, y): + out = test_graph_fallback.add_func(x, y) + return out + + x = Tensor(2, dtype=mstype.int32) + y = Tensor(3, dtype=mstype.int32) + out = use_imported_module(x, y) + print(out)