!32613 Supports creating and calling instances of ms_class
Merge pull request !32613 from huangbingjian/class_dev
This commit is contained in:
commit
4c3faa7f8f
|
@ -658,13 +658,19 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
|
||||||
return class_type;
|
return class_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the object is Cell Instance.
|
// Check if the object is Cell instance.
|
||||||
bool IsCellInstance(const py::object &obj) {
|
bool IsCellInstance(const py::object &obj) {
|
||||||
auto class_type = GetClassInstanceType(obj);
|
auto class_type = GetClassInstanceType(obj);
|
||||||
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
|
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
|
||||||
return is_cell;
|
return is_cell;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the object is class type.
|
||||||
|
bool IsClassType(const py::object &obj) {
|
||||||
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||||
|
return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast<bool>();
|
||||||
|
}
|
||||||
|
|
||||||
// Create the python class instance.
|
// Create the python class instance.
|
||||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
|
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
|
||||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
/**
|
/**
|
||||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||||
*
|
*
|
||||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
|
||||||
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
||||||
|
|
||||||
bool IsCellInstance(const py::object &obj);
|
bool IsCellInstance(const py::object &obj);
|
||||||
|
bool IsClassType(const py::object &obj);
|
||||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
||||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
|
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
|
||||||
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
|
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
|
||||||
|
|
|
@ -67,8 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
|
||||||
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
|
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
|
||||||
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
|
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
|
||||||
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
|
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
|
||||||
|
const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type";
|
||||||
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
|
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
|
||||||
const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr";
|
|
||||||
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
||||||
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
|
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
|
||||||
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
|
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
|
||||||
|
|
|
@ -67,6 +67,7 @@ struct AnfDumpHandlerRegister {
|
||||||
}
|
}
|
||||||
} callback_register;
|
} callback_register;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
abstract::AbstractBasePtr ClassObject::ToAbstract() {
|
abstract::AbstractBasePtr ClassObject::ToAbstract() {
|
||||||
ClassPtr cls_ptr = ParseDataClass(obj());
|
ClassPtr cls_ptr = ParseDataClass(obj());
|
||||||
auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
|
auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
|
||||||
|
@ -78,6 +79,24 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() {
|
||||||
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
|
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr MsClassObject::ToAbstract() {
|
||||||
|
auto abs_scalar =
|
||||||
|
std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<MsClassType>());
|
||||||
|
AbstractBasePtrList args_spec_list = {abs_scalar};
|
||||||
|
abstract::PrimitiveAbstractClosurePtr func_ptr = nullptr;
|
||||||
|
bool is_class_type = parse::data_converter::IsClassType(obj());
|
||||||
|
if (is_class_type) {
|
||||||
|
// Class type as func, such as Net(x, y)
|
||||||
|
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
|
||||||
|
} else {
|
||||||
|
// Class instance as func, such as net(x, y)
|
||||||
|
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCallInstance);
|
||||||
|
}
|
||||||
|
auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
|
||||||
|
ret_val->set_value_desc(ToString());
|
||||||
|
return ret_val;
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
|
static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
|
||||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||||
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
|
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
|
||||||
|
@ -520,14 +539,15 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsCl
|
||||||
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
|
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
|
||||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||||
|
|
||||||
|
constexpr size_t prefix_index = 0;
|
||||||
|
if (attr.size() > 0 && attr[prefix_index] == '_') {
|
||||||
|
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
|
||||||
|
}
|
||||||
py::object cls_obj = ms_class->obj();
|
py::object cls_obj = ms_class->obj();
|
||||||
if (!py::hasattr(cls_obj, attr.c_str())) {
|
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
|
||||||
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
|
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
|
||||||
}
|
}
|
||||||
|
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
|
||||||
const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR;
|
|
||||||
const std::string module = "mindspore._extends.parse.parser";
|
|
||||||
py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr);
|
|
||||||
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
|
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
|
||||||
TraceManager::ClearParseOrResolveDebugInfo();
|
TraceManager::ClearParseOrResolveDebugInfo();
|
||||||
return res_node;
|
return res_node;
|
||||||
|
|
|
@ -116,6 +116,7 @@ class PyObjectWrapper : public Named {
|
||||||
// the object that needs to be resolved
|
// the object that needs to be resolved
|
||||||
py::object obj_;
|
py::object obj_;
|
||||||
};
|
};
|
||||||
|
using PyObjectWrapperPtr = std::shared_ptr<PyObjectWrapper>;
|
||||||
|
|
||||||
// InterpretedObject class wrappers interpreted python object.
|
// InterpretedObject class wrappers interpreted python object.
|
||||||
class InterpretedObject final : public PyObjectWrapper {
|
class InterpretedObject final : public PyObjectWrapper {
|
||||||
|
@ -137,9 +138,7 @@ class MsClassObject final : public PyObjectWrapper {
|
||||||
: PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
|
: PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
|
||||||
~MsClassObject() override = default;
|
~MsClassObject() override = default;
|
||||||
MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
|
MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
|
||||||
abstract::AbstractBasePtr ToAbstract() override {
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<External>());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
using MsClassObjectPtr = std::shared_ptr<MsClassObject>;
|
using MsClassObjectPtr = std::shared_ptr<MsClassObject>;
|
||||||
|
|
||||||
|
|
|
@ -1304,6 +1304,35 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
|
||||||
return StaticGetterInferred(converted_value, data_conf, out_conf);
|
return StaticGetterInferred(converted_value, data_conf, out_conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||||
|
const ValuePtr &data_value, const ConfigPtr &data_conf,
|
||||||
|
const AnfNodeConfigPtr &out_conf) {
|
||||||
|
MS_EXCEPTION_IF_NULL(item_value);
|
||||||
|
MS_EXCEPTION_IF_NULL(data_value);
|
||||||
|
// Get the name of item.
|
||||||
|
if (!item_value->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
|
||||||
|
}
|
||||||
|
std::string item_name = item_value->cast<StringImmPtr>()->value();
|
||||||
|
// Get ms_class object.
|
||||||
|
if (!data_value->isa<parse::MsClassObject>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString();
|
||||||
|
}
|
||||||
|
auto ms_class = data_value->cast<parse::MsClassObjectPtr>();
|
||||||
|
MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << ".";
|
||||||
|
|
||||||
|
// Get the attr/method of ms_class object.
|
||||||
|
auto out_node = out_conf->node();
|
||||||
|
FuncGraphPtr func_graph = out_node->func_graph();
|
||||||
|
auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node);
|
||||||
|
// Replace old node with the resolved new node in order list.
|
||||||
|
func_graph->ReplaceInOrder(out_node, new_node);
|
||||||
|
AnalysisEnginePtr eng = out_conf->engine();
|
||||||
|
MS_EXCEPTION_IF_NULL(eng);
|
||||||
|
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||||
|
return eng->ForwardConfig(out_conf, fn_conf);
|
||||||
|
}
|
||||||
|
|
||||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||||
const AnfNodeConfigPtr &out_conf) {
|
const AnfNodeConfigPtr &out_conf) {
|
||||||
|
@ -1363,17 +1392,45 @@ int64_t GetResolveType(const TypePtr &data_type) {
|
||||||
return kResolveTypeFunction;
|
return kResolveTypeFunction;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
|
||||||
|
if (!abs->isa<abstract::PartialAbstractClosure>()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto partial_abs = abs->cast<abstract::PartialAbstractClosurePtr>();
|
||||||
|
auto fn = partial_abs->fn();
|
||||||
|
if (!fn->isa<abstract::PrimitiveAbstractClosure>()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// Check if type is kObjectTypeClass.
|
||||||
|
auto args = partial_abs->args();
|
||||||
|
if (args.size() > 0) {
|
||||||
|
constexpr size_t first_input_index = 0;
|
||||||
|
auto first_arg = args[first_input_index];
|
||||||
|
MS_EXCEPTION_IF_NULL(first_arg);
|
||||||
|
auto type = first_arg->BuildType();
|
||||||
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
if (type->type_id() == kObjectTypeClass) {
|
||||||
|
return first_arg->BuildValue();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||||
// Inputs: namespace and its static function; or class and its member function
|
// Inputs: namespace and its static function; or class and its member function
|
||||||
CheckArgsSize("StaticGetter", args_spec_list, 2);
|
CheckArgsSize("StaticGetter", args_spec_list, 2);
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
constexpr size_t data_index = 0;
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
constexpr size_t item_index = 1;
|
||||||
MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
|
auto data_args = args_spec_list[data_index];
|
||||||
MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
|
auto item_args = args_spec_list[item_index];
|
||||||
TypePtr data_type = args_spec_list[0]->BuildType();
|
MS_EXCEPTION_IF_NULL(data_args);
|
||||||
ValuePtr item_value = args_spec_list[1]->BuildValue();
|
MS_EXCEPTION_IF_NULL(item_args);
|
||||||
|
MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString();
|
||||||
|
TypePtr data_type = data_args->BuildType();
|
||||||
|
ValuePtr item_value = item_args->BuildValue();
|
||||||
|
|
||||||
ScopePtr scope = kDefaultScope;
|
ScopePtr scope = kDefaultScope;
|
||||||
if (out_conf != nullptr) {
|
if (out_conf != nullptr) {
|
||||||
scope = out_conf->node()->scope();
|
scope = out_conf->node()->scope();
|
||||||
|
@ -1384,6 +1441,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
||||||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto class_value = GetMsClassObject(data_args);
|
||||||
|
if (class_value != nullptr) {
|
||||||
|
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
|
||||||
|
}
|
||||||
int64_t resolve_type = GetResolveType(data_type);
|
int64_t resolve_type = GetResolveType(data_type);
|
||||||
if (resolve_type == kResolveTypeUserDefineClass) {
|
if (resolve_type == kResolveTypeUserDefineClass) {
|
||||||
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
|
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
|
||||||
|
@ -1581,46 +1642,47 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
||||||
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
|
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
|
||||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||||
const AnfNodeConfigPtr &out_conf) override {
|
const AnfNodeConfigPtr &out_conf) override {
|
||||||
|
// Check the type parameter.
|
||||||
if (args_spec_list.empty()) {
|
if (args_spec_list.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
||||||
}
|
}
|
||||||
|
constexpr size_t type_index = 0;
|
||||||
// Get the type parameter.
|
auto arg_class_type = args_spec_list[type_index];
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
MS_EXCEPTION_IF_NULL(arg_class_type);
|
||||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
TypePtr type = arg_class_type->GetTypeTrack();
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
if (type->type_id() != kMetaTypeTypeType) {
|
if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) {
|
||||||
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
|
MS_LOG(EXCEPTION)
|
||||||
<< type->ToString();
|
<< "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got "
|
||||||
|
<< type->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
ValuePtr value_track = arg_class_type->GetValueTrack();
|
||||||
MS_EXCEPTION_IF_NULL(value_track);
|
MS_EXCEPTION_IF_NULL(value_track);
|
||||||
|
parse::PyObjectWrapperPtr type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
|
||||||
std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
|
|
||||||
if (type_obj == nullptr) {
|
if (type_obj == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
|
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
|
||||||
}
|
}
|
||||||
|
if (!type_obj->isa<parse::ClassType>() && !type_obj->isa<parse::MsClassObject>()) {
|
||||||
if (!type_obj->isa<parse::ClassType>()) {
|
MS_LOG(EXCEPTION)
|
||||||
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
|
<< "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got "
|
||||||
<< type_obj->ToString() << ".";
|
<< type_obj->ToString() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto class_type = type_obj->obj();
|
auto class_type = type_obj->obj();
|
||||||
MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
|
MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
|
||||||
|
|
||||||
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
|
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
|
||||||
py::tuple params = GetParameters(args_spec_list);
|
py::tuple params = GetParameters(args_spec_list);
|
||||||
|
|
||||||
// Create class instance.
|
// Create class instance.
|
||||||
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
|
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
|
||||||
if (py::isinstance<py::none>(obj)) {
|
if (py::isinstance<py::none>(obj)) {
|
||||||
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
|
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
|
||||||
<< "` failed, only support to create \'Cell\' or \'Primitive\' object.";
|
<< "` failed, only support to create \'Cell\', \'Primitive\' or "
|
||||||
|
<< "user-defined Class decorated with \'ms_class\'.";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the object.
|
// Process the object.
|
||||||
|
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||||
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
|
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
|
||||||
ValuePtr converted_ret = nullptr;
|
ValuePtr converted_ret = nullptr;
|
||||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||||
|
@ -1628,7 +1690,6 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
||||||
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(converted_ret);
|
MS_EXCEPTION_IF_NULL(converted_ret);
|
||||||
|
|
||||||
if (converted_ret->isa<FuncGraph>()) {
|
if (converted_ret->isa<FuncGraph>()) {
|
||||||
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
|
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
|
||||||
}
|
}
|
||||||
|
@ -1664,6 +1725,63 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class CallInstanceEvaluator : public TransitionPrimEvaluator {
|
||||||
|
public:
|
||||||
|
CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {}
|
||||||
|
~CallInstanceEvaluator() override = default;
|
||||||
|
MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator);
|
||||||
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||||
|
const AnfNodeConfigPtr &out_conf) override {
|
||||||
|
if (args_spec_list.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "args_spec_list should not be empty.";
|
||||||
|
}
|
||||||
|
constexpr size_t cls_index = 0;
|
||||||
|
auto arg_cls = args_spec_list[cls_index];
|
||||||
|
MS_EXCEPTION_IF_NULL(arg_cls);
|
||||||
|
TypePtr type = arg_cls->GetTypeTrack();
|
||||||
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
if (type->type_id() != kObjectTypeClass) {
|
||||||
|
MS_LOG(EXCEPTION) << "CallInstanceEvaluator require first parameter should be an object of TypeClass, but got "
|
||||||
|
<< type->ToString();
|
||||||
|
}
|
||||||
|
ValuePtr value_track = arg_cls->GetValueTrack();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_track);
|
||||||
|
parse::MsClassObjectPtr ms_class = dyn_cast<parse::MsClassObject>(value_track);
|
||||||
|
if (ms_class == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject.";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call class instance, net(x, y) -> net.__call__(x, y)
|
||||||
|
py::object cls_obj = ms_class->obj();
|
||||||
|
const std::string call_func = "__call__";
|
||||||
|
if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
|
||||||
|
MS_LOG(EXCEPTION) << ms_class->name() << " has no " << call_func << " function, please check the code.";
|
||||||
|
}
|
||||||
|
py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
|
||||||
|
FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
|
||||||
|
if (call_func_graph == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Parse python object " << call_func << " failed.";
|
||||||
|
}
|
||||||
|
FuncGraphManagerPtr manager = engine->func_graph_manager();
|
||||||
|
manager->AddFuncGraph(call_func_graph);
|
||||||
|
|
||||||
|
// Replace net with net.__call__
|
||||||
|
AnfNodePtr old_node = out_conf->node();
|
||||||
|
MS_EXCEPTION_IF_NULL(old_node);
|
||||||
|
CNodePtr old_cnode = dyn_cast<CNode>(old_node);
|
||||||
|
MS_EXCEPTION_IF_NULL(old_cnode);
|
||||||
|
std::vector<AnfNodePtr> inputs = {NewValueNode(call_func_graph)};
|
||||||
|
for (size_t i = 1; i < old_cnode->size(); i++) {
|
||||||
|
(void)inputs.emplace_back(old_cnode->input(i));
|
||||||
|
}
|
||||||
|
FuncGraphPtr func_graph = out_conf->func_graph();
|
||||||
|
auto new_cnode = func_graph->NewCNode(inputs);
|
||||||
|
// Continue to eval new_cnode.
|
||||||
|
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
|
||||||
|
return engine->ForwardConfig(out_conf, fn_conf);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
||||||
public:
|
public:
|
||||||
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
|
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
|
||||||
|
@ -2085,6 +2203,7 @@ void InitPrimEvaluatorConstructors() {
|
||||||
constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
|
constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
|
||||||
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
|
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
|
||||||
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
|
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
|
||||||
|
constructor[prim::kPrimCallInstance] = std::make_shared<CallInstanceEvaluator>();
|
||||||
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
|
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
|
||||||
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
|
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
|
||||||
constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();
|
constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();
|
||||||
|
|
|
@ -82,7 +82,7 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
|
||||||
if (abstract->isa<AbstractScalar>()) {
|
if (abstract->isa<AbstractScalar>()) {
|
||||||
TypePtr type = abstract->GetTypeTrack();
|
TypePtr type = abstract->GetTypeTrack();
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
if (type->isa<EnvType>()) {
|
if (type->isa<EnvType>() || type->isa<MsClassType>()) {
|
||||||
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
|
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
|
||||||
}
|
}
|
||||||
if (type->isa<Problem>() || type->isa<External>()) {
|
if (type->isa<Problem>() || type->isa<External>()) {
|
||||||
|
|
|
@ -345,6 +345,24 @@ class MS_CORE_API Problem final : public Type {
|
||||||
};
|
};
|
||||||
using ProblemPtr = std::shared_ptr<Problem>;
|
using ProblemPtr = std::shared_ptr<Problem>;
|
||||||
|
|
||||||
|
/// \brief MsClassType defines a type which is ms_class.
|
||||||
|
class MS_CORE_API MsClassType final : public Type {
|
||||||
|
public:
|
||||||
|
/// \brief The constructor of External.
|
||||||
|
///
|
||||||
|
/// \return The instance of External.
|
||||||
|
MsClassType() : Type(kObjectTypeClass) {}
|
||||||
|
|
||||||
|
/// \brief The destructor of External.
|
||||||
|
~MsClassType() override = default;
|
||||||
|
MS_DECLARE_PARENT(MsClassType, Type)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kObjectTypeClass; }
|
||||||
|
TypePtr DeepCopy() const override { return std::make_shared<MsClassType>(); }
|
||||||
|
std::string DumpText() const override { return "MsClassType"; }
|
||||||
|
};
|
||||||
|
using MsClassTypePtr = std::shared_ptr<MsClassType>;
|
||||||
|
|
||||||
/// \brief External defines a type which is external.
|
/// \brief External defines a type which is external.
|
||||||
class MS_CORE_API External final : public Type {
|
class MS_CORE_API External final : public Type {
|
||||||
public:
|
public:
|
||||||
|
@ -360,9 +378,6 @@ class MS_CORE_API External final : public Type {
|
||||||
TypeId generic_type_id() const override { return kMetaTypeExternal; }
|
TypeId generic_type_id() const override { return kMetaTypeExternal; }
|
||||||
TypePtr DeepCopy() const override { return std::make_shared<External>(); }
|
TypePtr DeepCopy() const override { return std::make_shared<External>(); }
|
||||||
std::string DumpText() const override { return "ExternalType"; }
|
std::string DumpText() const override { return "ExternalType"; }
|
||||||
|
|
||||||
private:
|
|
||||||
TypePtr kind;
|
|
||||||
};
|
};
|
||||||
using ExternalPtr = std::shared_ptr<External>;
|
using ExternalPtr = std::shared_ptr<External>;
|
||||||
|
|
||||||
|
|
|
@ -939,6 +939,7 @@ GVAR_DEF(PrimitivePtr, kPrimResolve, std::make_shared<Primitive>("resolve"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared<Primitive>("embed"));
|
GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared<Primitive>("embed"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared<Primitive>("RefToEmbed"));
|
GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared<Primitive>("RefToEmbed"));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared<Primitive>("create_instance"));
|
GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared<Primitive>("create_instance"));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimCallInstance, std::make_shared<Primitive>("call_instance"));
|
||||||
|
|
||||||
// Other miscellaneous
|
// Other miscellaneous
|
||||||
GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared<Primitive>("get_ref_origin"));
|
GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared<Primitive>("get_ref_origin"));
|
||||||
|
|
|
@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
||||||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
|
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
|
||||||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
|
"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
|
||||||
#else
|
#else
|
||||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||||
|
@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
||||||
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
|
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
|
||||||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
|
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
|
||||||
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
|
"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
|
||||||
#endif
|
#endif
|
||||||
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
|
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
|
||||||
prim::kPrimMicroStepAllGather};
|
prim::kPrimMicroStepAllGather};
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -18,20 +18,18 @@ Interfaces for parser module in c++.
|
||||||
|
|
||||||
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
|
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
|
||||||
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
|
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
|
||||||
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
create_slice_obj, get_obj_id, get_module_namespace, get_obj_type, get_object_key,
|
||||||
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
|
get_ast_type, get_node_type, get_args, get_args_default_values, get_ast_namespace_symbol,
|
||||||
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
|
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
|
||||||
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
|
eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||||
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
|
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||||
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
is_class_type, get_dataclass_attributes, get_dataclass_methods)
|
||||||
get_ms_class_attr)
|
|
||||||
|
|
||||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
|
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||||
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',
|
'create_slice_obj', 'get_obj_id', 'get_module_namespace', 'get_obj_type', 'get_object_key',
|
||||||
'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type',
|
'get_ast_type', 'get_node_type', 'get_args', 'get_args_default_values', 'get_ast_namespace_symbol',
|
||||||
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
|
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
|
||||||
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
|
'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||||
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
|
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||||
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods']
|
||||||
'get_ms_class_attr']
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||||
#
|
#
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -118,7 +118,11 @@ class ClassMemberNamespace(Namespace):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise UnboundLocalError(name)
|
raise UnboundLocalError(name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
# Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown.
|
||||||
|
cls = d.__class__
|
||||||
|
if hasattr(cls, '__ms_class__'):
|
||||||
|
raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.")
|
||||||
|
logger.info(f"'{cls.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
||||||
raise AttributeError(name)
|
raise AttributeError(name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,5 +146,4 @@ class ClassAttrNamespace(Namespace):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise UnboundLocalError(name)
|
raise UnboundLocalError(name)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
raise AttributeError(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}'.")
|
||||||
raise AttributeError(name)
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||||
#
|
#
|
||||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -23,7 +23,6 @@ import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import types
|
import types
|
||||||
import importlib
|
import importlib
|
||||||
from dataclasses import is_dataclass
|
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
import asttokens
|
import asttokens
|
||||||
|
@ -324,24 +323,26 @@ def get_class_instance_type(obj):
|
||||||
"""Get the class instance detail type."""
|
"""Get the class instance detail type."""
|
||||||
# check the obj type
|
# check the obj type
|
||||||
logger.debug("Get the class type(%r)", obj)
|
logger.debug("Get the class type(%r)", obj)
|
||||||
class_type = CLASS_INSTANCE_TYPE_INVALID
|
if isinstance(obj, nn.Cell):
|
||||||
if _is_class_instance(obj):
|
return CLASS_INSTANCE_TYPE_CELL
|
||||||
if isinstance(obj, nn.Cell):
|
if isinstance(obj, ops.Primitive):
|
||||||
class_type = CLASS_INSTANCE_TYPE_CELL
|
return CLASS_INSTANCE_TYPE_PRIMITIVE
|
||||||
elif isinstance(obj, ops.Primitive):
|
return CLASS_INSTANCE_TYPE_INVALID
|
||||||
class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
|
|
||||||
# Add the other type base requirement
|
|
||||||
return class_type
|
|
||||||
|
|
||||||
|
|
||||||
def _is_class_instance(obj):
|
def _is_ms_class(obj):
|
||||||
"""Confirm the obj is class instance."""
|
"""Check if obj is ms_class object."""
|
||||||
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
|
return hasattr(obj, '__ms_class__')
|
||||||
|
|
||||||
|
|
||||||
def _is_dataclass_instance(obj):
|
def _is_dataclass_instance(obj):
|
||||||
"""Check whether a class is an instance of a dataclass (and not a dataclass itself)"""
|
"""Check whether a class is an instance of a dataclass (and not a dataclass itself)"""
|
||||||
return is_dataclass(obj) and not isinstance(obj, type)
|
return hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_class_instance(obj):
|
||||||
|
"""Confirm the obj is class instance."""
|
||||||
|
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) or _is_ms_class(obj)
|
||||||
|
|
||||||
|
|
||||||
def _convert_tuple_to_args_kwargs(params):
|
def _convert_tuple_to_args_kwargs(params):
|
||||||
|
@ -358,7 +359,7 @@ def _convert_tuple_to_args_kwargs(params):
|
||||||
|
|
||||||
def is_supported_create_instance_type(cls_type):
|
def is_supported_create_instance_type(cls_type):
|
||||||
"""Check if cls_type is a supported instance type."""
|
"""Check if cls_type is a supported instance type."""
|
||||||
return issubclass(cls_type, (nn.Cell, ops.Primitive))
|
return issubclass(cls_type, (nn.Cell, ops.Primitive)) or _is_ms_class(cls_type)
|
||||||
|
|
||||||
|
|
||||||
def create_instance(cls_type, params=None):
|
def create_instance(cls_type, params=None):
|
||||||
|
@ -440,28 +441,19 @@ def get_dataclass_methods(cls):
|
||||||
return methods
|
return methods
|
||||||
|
|
||||||
|
|
||||||
|
def is_class_type(cls):
|
||||||
|
"""Check if cls is a class type."""
|
||||||
|
return isinstance(cls, type)
|
||||||
|
|
||||||
|
|
||||||
def get_ms_class_name(cls):
|
def get_ms_class_name(cls):
|
||||||
"""Get the name of the class instance decorated by ms_class."""
|
"""Get the name of the class instance decorated by ms_class."""
|
||||||
# Check if cls is nn.Cell.
|
# Check if cls is nn.Cell.
|
||||||
if isinstance(cls, nn.Cell):
|
if isinstance(cls, nn.Cell):
|
||||||
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
||||||
if isinstance(cls, type):
|
if isinstance(cls, type):
|
||||||
name = cls.__name__
|
return cls.__name__
|
||||||
else:
|
return cls.__class__.__name__
|
||||||
name = cls.__class__.__name__
|
|
||||||
# Get the name of cls.
|
|
||||||
cls_name = cls.__module__ + '.' + name
|
|
||||||
return cls_name
|
|
||||||
|
|
||||||
|
|
||||||
def get_ms_class_attr(cls, name: str):
|
|
||||||
"""Get attribute or method of ms_class obj."""
|
|
||||||
# Don't take into account python magic methods and private variables.
|
|
||||||
if name.startswith('_'):
|
|
||||||
raise AttributeError(f"{name} is a private variable or magic method, which is not supported.")
|
|
||||||
if not hasattr(cls, name):
|
|
||||||
raise AttributeError(f"{cls} has no attribute: {name}.")
|
|
||||||
return getattr(cls, name)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_ms_tensor(data):
|
def convert_to_ms_tensor(data):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||||
#
|
#
|
||||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -15,10 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""standard_method"""
|
"""standard_method"""
|
||||||
|
from mindspore import Tensor, Parameter, CSRTensor, COOTensor, ms_class
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
|
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
|
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
@ -1828,16 +1825,16 @@ def float_floordiv(x, y):
|
||||||
#############
|
#############
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@ms_class
|
||||||
class SequenceIterator:
|
class SequenceIterator:
|
||||||
"""
|
"""
|
||||||
SequenceIterator is a util dataclass for iterating sequence object.
|
SequenceIterator is a util class for iterating sequence object.
|
||||||
|
|
||||||
Iterator to use for sequences like List, Array.
|
Iterator to use for sequences like List, Array.
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, idx, seq):
|
||||||
idx: int
|
self.idx = idx
|
||||||
seq: list
|
self.seq = seq
|
||||||
|
|
||||||
@core(ignore_values=True)
|
@core(ignore_values=True)
|
||||||
def __ms_hasnext__(self):
|
def __ms_hasnext__(self):
|
||||||
|
|
|
@ -0,0 +1,451 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test graph fallback """
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore import Tensor, context, ms_class
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_attr():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the attributes of user-defined classes decorated by ms_class.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self):
|
||||||
|
self.number = Tensor(1, dtype=mstype.int32)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet()
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
out = self.inner_net.number
|
||||||
|
return out
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
out = net()
|
||||||
|
assert out.asnumpy() == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_method():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the methods of user-defined classes decorated by ms_class.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self):
|
||||||
|
self.val = Tensor(2, dtype=mstype.int32)
|
||||||
|
|
||||||
|
def act(self, x, y):
|
||||||
|
return self.val * (x + y)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.inner_net.act(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor(2, dtype=mstype.int32)
|
||||||
|
y = Tensor(3, dtype=mstype.int32)
|
||||||
|
net = Net()
|
||||||
|
out = net(x, y)
|
||||||
|
assert out.asnumpy() == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_call():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Call the __call__ function of user-defined classes decorated by ms_class.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self, val):
|
||||||
|
self.val = val
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return self.val * (x + y)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, val):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet(val)
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
out = self.inner_net(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
val = Tensor(2, dtype=mstype.int32)
|
||||||
|
x = Tensor(3, dtype=mstype.int32)
|
||||||
|
y = Tensor(4, dtype=mstype.int32)
|
||||||
|
net = Net(val)
|
||||||
|
out = net(x, y)
|
||||||
|
assert out.asnumpy() == 14
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_input_attr():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the attributes of user-defined classes decorated by ms_class.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self):
|
||||||
|
self.number = Tensor(np.array([1, 2, 3]))
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = net()
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
out = self.inner_net.number
|
||||||
|
return out
|
||||||
|
|
||||||
|
net = Net(InnerNet)
|
||||||
|
out = net()
|
||||||
|
expect_res = np.array([1, 2, 3])
|
||||||
|
assert np.all(out.asnumpy() == expect_res)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_input_method():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the methods of user-defined classes decorated by ms_class.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self):
|
||||||
|
self.val = Tensor(2, dtype=mstype.int32)
|
||||||
|
|
||||||
|
def act(self, x, y):
|
||||||
|
return self.val * (x + y)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = net()
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
out = self.inner_net.act(1, 2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
net = Net(InnerNet)
|
||||||
|
out = net()
|
||||||
|
assert out.asnumpy() == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_class_nested():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test nested ms_class in graph.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class Inner:
|
||||||
|
def __init__(self):
|
||||||
|
self.number = Tensor(1, dtype=mstype.int32)
|
||||||
|
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self):
|
||||||
|
self.inner = Inner()
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet()
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
out = self.inner_net.inner.number
|
||||||
|
return out
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
out = net()
|
||||||
|
assert out.asnumpy() == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_cell_nested():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Test nested ms_class and cell in graph.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, val):
|
||||||
|
super().__init__()
|
||||||
|
self.val = val
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return x + self.val
|
||||||
|
|
||||||
|
@ms_class
|
||||||
|
class TrainNet():
|
||||||
|
class Loss(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super().__init__()
|
||||||
|
self.net = net
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
out = self.net(x)
|
||||||
|
return out * 2
|
||||||
|
|
||||||
|
def __init__(self, net):
|
||||||
|
self.net = net
|
||||||
|
loss_net = self.Loss(self.net)
|
||||||
|
self.number = loss_net(10)
|
||||||
|
|
||||||
|
global_net = Net(1)
|
||||||
|
class LearnNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.value = TrainNet(global_net).number
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return x + self.value
|
||||||
|
|
||||||
|
leanrn_net = LearnNet()
|
||||||
|
out = leanrn_net(3)
|
||||||
|
print(out)
|
||||||
|
assert out == 25
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_type_attr():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the attributes of class type.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
val = Tensor(2, dtype=mstype.int32)
|
||||||
|
|
||||||
|
def act(self, x, y):
|
||||||
|
return self.val * (x + y)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet
|
||||||
|
|
||||||
|
# Support accessing attributes of class type, but do not support
|
||||||
|
# accessing methods, e.g. self.inner_net.act(1, 2)
|
||||||
|
def construct(self):
|
||||||
|
out = self.inner_net.val
|
||||||
|
return out
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
out = net()
|
||||||
|
assert out == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_create_instance_attr():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the attributes of the created class instance.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self, val):
|
||||||
|
self.number = val + 3
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
net = self.inner_net(x)
|
||||||
|
return net.number
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
out = net(2)
|
||||||
|
assert out == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_create_instance_method():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Access the methods of the created class instance.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self, val):
|
||||||
|
self.number = val
|
||||||
|
|
||||||
|
def act(self, x, y):
|
||||||
|
return self.number * (x + y)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
net = self.inner_net(x)
|
||||||
|
return net.act(y, z)
|
||||||
|
|
||||||
|
x = 2
|
||||||
|
y = Tensor(2, dtype=mstype.int32)
|
||||||
|
z = Tensor(3, dtype=mstype.int32)
|
||||||
|
net = Net()
|
||||||
|
out = net(x, y, z)
|
||||||
|
assert out.asnumpy() == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_class_create_instance_call():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Call the __call__ function of the created class instance.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class InnerNet:
|
||||||
|
def __init__(self, number):
|
||||||
|
self.number = number
|
||||||
|
|
||||||
|
def __call__(self, x, y):
|
||||||
|
return self.number * (x + y)
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.inner_net = InnerNet
|
||||||
|
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
net = self.inner_net(x)
|
||||||
|
out = net(y, z)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = 2
|
||||||
|
y = Tensor(2, dtype=mstype.int32)
|
||||||
|
z = Tensor(3, dtype=mstype.int32)
|
||||||
|
net = Net()
|
||||||
|
out = net(x, y, z)
|
||||||
|
assert out == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_raise_error_not_class_type():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Decorator ms_class cannot be used for non-class types.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
@ms_class
|
||||||
|
def func(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
func(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_fallback_raise_error_decorate_cell():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Decorator ms_class cannot be used for nn.Cell
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_class
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def construct(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
x = Tensor(1)
|
||||||
|
net = Net()
|
||||||
|
net(x)
|
|
@ -1,416 +0,0 @@
|
||||||
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
""" test graph fallback """
|
|
||||||
import pytest
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.nn as nn
|
|
||||||
import mindspore.common.dtype as mstype
|
|
||||||
from mindspore import Tensor, context, ms_class, ms_function
|
|
||||||
from . import test_graph_fallback
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_self_attr():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test self.attr in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
class Network(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Network, self).__init__()
|
|
||||||
self.dim = 1
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
batch = x.shape[0]
|
|
||||||
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
|
|
||||||
return one * x
|
|
||||||
|
|
||||||
net = Network()
|
|
||||||
x = Tensor([1, 2], mstype.float32)
|
|
||||||
out = net(x)
|
|
||||||
expect = np.array([[1., 2.], [1., 2.]])
|
|
||||||
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_self_attr_fn():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test self.attr in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
class Network(nn.Cell):
|
|
||||||
def __init__(self, fn):
|
|
||||||
super(Network, self).__init__()
|
|
||||||
self.fn = fn
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
x = np.array([1, 2, 3])
|
|
||||||
y = np.array([3, 4, 5])
|
|
||||||
out = Tensor(self.fn(x, y))
|
|
||||||
return out
|
|
||||||
|
|
||||||
def fn(x, y):
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
net = Network(fn)
|
|
||||||
out = net()
|
|
||||||
expect = np.array([4, 6, 8])
|
|
||||||
assert np.all(out.asnumpy() == expect)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_self_attr_attr():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test self.attr in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
class Network(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Network, self).__init__()
|
|
||||||
self.value = [2, 2, 3]
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
x = np.array(self.value.count(2))
|
|
||||||
return Tensor(x)
|
|
||||||
|
|
||||||
net = Network()
|
|
||||||
out = net()
|
|
||||||
assert out == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_self_method():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test self.method in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
class Network(nn.Cell):
|
|
||||||
def construct(self):
|
|
||||||
x = np.array([1, 2, 3])
|
|
||||||
y = np.array([3, 4, 5])
|
|
||||||
out = Tensor(self.fn(x, y))
|
|
||||||
return out
|
|
||||||
|
|
||||||
def fn(self, x, y):
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
net = Network()
|
|
||||||
out = net()
|
|
||||||
expect = np.array([4, 6, 8])
|
|
||||||
assert np.all(out.asnumpy() == expect)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
|
||||||
def test_fallback_self_method_tensor():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test self.method in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
class Network(nn.Cell):
|
|
||||||
def construct(self):
|
|
||||||
x = np.array([1, 2, 3])
|
|
||||||
y = np.array([3, 4, 5])
|
|
||||||
z = self.fn(x, y)
|
|
||||||
out = Tensor(z)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def fn(self, x, y):
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
net = Network()
|
|
||||||
out = net()
|
|
||||||
print(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_import_modules():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: add_func is defined in test_graph_fallback.py
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_function
|
|
||||||
def use_imported_module(x, y):
|
|
||||||
out = test_graph_fallback.add_func(x, y)
|
|
||||||
return out
|
|
||||||
|
|
||||||
x = Tensor(2, dtype=mstype.int32)
|
|
||||||
y = Tensor(3, dtype=mstype.int32)
|
|
||||||
out = use_imported_module(x, y)
|
|
||||||
print(out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_class_attr():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test user-defined class attributes in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.number = 1
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.inner_net = InnerNet()
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
out = self.inner_net.number
|
|
||||||
return out
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert out == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_class_method():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test user-defined class methods in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.val = 2
|
|
||||||
|
|
||||||
def act(self, x, y):
|
|
||||||
return self.val * (x + y)
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.inner_net = InnerNet()
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
out = self.inner_net.act(1, 2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert out == 6
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_class_input_attr():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test user-defined class attributes in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.number = Tensor(np.array([1, 2, 3]))
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self, net):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.inner_net = net()
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
out = self.inner_net.number
|
|
||||||
return out
|
|
||||||
|
|
||||||
net = Net(InnerNet)
|
|
||||||
out = net()
|
|
||||||
expect_res = np.array([1, 2, 3])
|
|
||||||
assert np.all(out.asnumpy() == expect_res)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_class_input_method():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test user-defined class methods in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.val = 2
|
|
||||||
|
|
||||||
def act(self, x, y):
|
|
||||||
return self.val * (x + y)
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self, net):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.inner_net = net()
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
out = self.inner_net.act(1, 2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
net = Net(InnerNet)
|
|
||||||
out = net()
|
|
||||||
assert out == 6
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_class_class_nested():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test nested ms_class in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class Inner:
|
|
||||||
def __init__(self):
|
|
||||||
self.number = 1
|
|
||||||
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.inner = Inner()
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.inner_net = InnerNet()
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
out = self.inner_net.inner.number
|
|
||||||
return out
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net()
|
|
||||||
assert out == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_class_cell_nested():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test nested ms_class and cell in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self, val):
|
|
||||||
super().__init__()
|
|
||||||
self.val = val
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
return x + self.val
|
|
||||||
|
|
||||||
@ms_class
|
|
||||||
class TrainNet():
|
|
||||||
class Loss(nn.Cell):
|
|
||||||
def __init__(self, net):
|
|
||||||
super().__init__()
|
|
||||||
self.net = net
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
out = self.net(x)
|
|
||||||
return out * 2
|
|
||||||
|
|
||||||
def __init__(self, net):
|
|
||||||
self.net = net
|
|
||||||
loss_net = self.Loss(self.net)
|
|
||||||
self.number = loss_net(10)
|
|
||||||
|
|
||||||
global_net = Net(1)
|
|
||||||
class LearnNet(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.value = TrainNet(global_net).number
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
return x + self.value
|
|
||||||
|
|
||||||
leanrn_net = LearnNet()
|
|
||||||
out = leanrn_net(3)
|
|
||||||
print(out)
|
|
||||||
assert out == 25
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason='Not support in graph yet')
|
|
||||||
def test_fallback_class_isinstance():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test ms_class in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.number = 1
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.inner_net = InnerNet()
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
if isinstance(self.inner_net, InnerNet):
|
|
||||||
return x + 10
|
|
||||||
return x
|
|
||||||
|
|
||||||
net = Net()
|
|
||||||
out = net(5)
|
|
||||||
assert out == 15
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_raise_error_not_class_type():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test ms_class in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
@ms_class
|
|
||||||
def func(x, y):
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
func(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_raise_error_not_class_instance():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test ms_class in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class InnerNet:
|
|
||||||
def __init__(self):
|
|
||||||
self.number = 1
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def construct(self):
|
|
||||||
out = InnerNet().number
|
|
||||||
return out
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
net = Net()
|
|
||||||
net()
|
|
||||||
|
|
||||||
|
|
||||||
def test_fallback_raise_error_decorate_cell():
|
|
||||||
"""
|
|
||||||
Feature: JIT Fallback
|
|
||||||
Description: Test ms_class in graph.
|
|
||||||
Expectation: No exception.
|
|
||||||
"""
|
|
||||||
@ms_class
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def construct(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
x = Tensor(1)
|
|
||||||
net = Net()
|
|
||||||
net(x)
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test graph fallback """
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore import Tensor, context, ms_function
|
||||||
|
from . import test_graph_fallback
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_self_attr():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Use self.attr in expressions supported by JIT Fallback.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
class Network(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Network, self).__init__()
|
||||||
|
self.dim = 1
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
batch = x.shape[0]
|
||||||
|
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
|
||||||
|
return one * x
|
||||||
|
|
||||||
|
net = Network()
|
||||||
|
x = Tensor([1, 2], mstype.float32)
|
||||||
|
out = net(x)
|
||||||
|
expect = np.array([[1., 2.], [1., 2.]])
|
||||||
|
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_self_attr_fn():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Use self.attr of type function in expressions supported by JIT Fallback.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
class Network(nn.Cell):
|
||||||
|
def __init__(self, fn):
|
||||||
|
super(Network, self).__init__()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
x = np.array([1, 2, 3])
|
||||||
|
y = np.array([3, 4, 5])
|
||||||
|
out = Tensor(self.fn(x, y))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
net = Network(fn)
|
||||||
|
out = net()
|
||||||
|
expect = np.array([4, 6, 8])
|
||||||
|
assert np.all(out.asnumpy() == expect)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_self_attr_attr():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: In expressions supported by JIT Fallback, use the attribute of self.attr.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
class Network(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Network, self).__init__()
|
||||||
|
self.value = [2, 2, 3]
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
x = np.array(self.value.count(2))
|
||||||
|
return Tensor(x)
|
||||||
|
|
||||||
|
net = Network()
|
||||||
|
out = net()
|
||||||
|
assert out == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_self_method():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Use self.method in expressions supported by JIT Fallback.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
class Network(nn.Cell):
|
||||||
|
def construct(self):
|
||||||
|
x = np.array([1, 2, 3])
|
||||||
|
y = np.array([3, 4, 5])
|
||||||
|
out = Tensor(self.fn(x, y))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def fn(self, x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
net = Network()
|
||||||
|
out = net()
|
||||||
|
expect = np.array([4, 6, 8])
|
||||||
|
assert np.all(out.asnumpy() == expect)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_import_modules():
|
||||||
|
"""
|
||||||
|
Feature: JIT Fallback
|
||||||
|
Description: Check whether the call to the third-party library is correct. It has nothing to do with class.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_function
|
||||||
|
def use_imported_module(x, y):
|
||||||
|
out = test_graph_fallback.add_func(x, y)
|
||||||
|
return out
|
||||||
|
|
||||||
|
x = Tensor(2, dtype=mstype.int32)
|
||||||
|
y = Tensor(3, dtype=mstype.int32)
|
||||||
|
out = use_imported_module(x, y)
|
||||||
|
print(out)
|
Loading…
Reference in New Issue