!32613 Supports creating and calling instances of ms_class
Merge pull request !32613 from huangbingjian/class_dev
This commit is contained in:
commit
4c3faa7f8f
|
@ -658,13 +658,19 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
|
|||
return class_type;
|
||||
}
|
||||
|
||||
// Check the object is Cell Instance.
|
||||
// Check if the object is Cell instance.
|
||||
bool IsCellInstance(const py::object &obj) {
|
||||
auto class_type = GetClassInstanceType(obj);
|
||||
bool is_cell = (class_type == CLASS_INSTANCE_TYPE_CELL);
|
||||
return is_cell;
|
||||
}
|
||||
|
||||
// Check if the object is class type.
|
||||
bool IsClassType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
return python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_CLASS_TYPE, obj).cast<bool>();
|
||||
}
|
||||
|
||||
// Create the python class instance.
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -44,6 +44,7 @@ ResolveTypeDef GetObjType(const py::object &obj);
|
|||
ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
|
||||
|
||||
bool IsCellInstance(const py::object &obj);
|
||||
bool IsClassType(const py::object &obj);
|
||||
py::object CreatePythonObject(const py::object &type, const py::tuple &args_kwargs);
|
||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs);
|
||||
void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
|
||||
|
|
|
@ -67,8 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance";
|
|||
const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type";
|
||||
const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes";
|
||||
const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods";
|
||||
const char PYTHON_MOD_IS_CLASS_TYPE[] = "is_class_type";
|
||||
const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name";
|
||||
const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr";
|
||||
const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
|
||||
const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol";
|
||||
const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol";
|
||||
|
|
|
@ -67,6 +67,7 @@ struct AnfDumpHandlerRegister {
|
|||
}
|
||||
} callback_register;
|
||||
} // namespace
|
||||
|
||||
abstract::AbstractBasePtr ClassObject::ToAbstract() {
|
||||
ClassPtr cls_ptr = ParseDataClass(obj());
|
||||
auto abs_scalar = std::make_shared<abstract::AbstractScalar>();
|
||||
|
@ -78,6 +79,24 @@ abstract::AbstractBasePtr ClassObject::ToAbstract() {
|
|||
return std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr MsClassObject::ToAbstract() {
|
||||
auto abs_scalar =
|
||||
std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<MsClassType>());
|
||||
AbstractBasePtrList args_spec_list = {abs_scalar};
|
||||
abstract::PrimitiveAbstractClosurePtr func_ptr = nullptr;
|
||||
bool is_class_type = parse::data_converter::IsClassType(obj());
|
||||
if (is_class_type) {
|
||||
// Class type as func, such as Net(x, y)
|
||||
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCreateInstance);
|
||||
} else {
|
||||
// Class instance as func, such as net(x, y)
|
||||
func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(prim::kPrimCallInstance);
|
||||
}
|
||||
auto ret_val = std::make_shared<abstract::PartialAbstractClosure>(func_ptr, args_spec_list);
|
||||
ret_val->set_value_desc(ToString());
|
||||
return ret_val;
|
||||
}
|
||||
|
||||
static inline bool IsSupportedCreateInstanceType(const py::object &obj) {
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
auto res = python_adapter::CallPyModFn(mod, PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE, obj);
|
||||
|
@ -520,14 +539,15 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsCl
|
|||
MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
|
||||
|
||||
constexpr size_t prefix_index = 0;
|
||||
if (attr.size() > 0 && attr[prefix_index] == '_') {
|
||||
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
|
||||
}
|
||||
py::object cls_obj = ms_class->obj();
|
||||
if (!py::hasattr(cls_obj, attr.c_str())) {
|
||||
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
|
||||
MS_LOG(EXCEPTION) << ms_class->name() << " has not attribute: " << attr << ".";
|
||||
}
|
||||
|
||||
const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR;
|
||||
const std::string module = "mindspore._extends.parse.parser";
|
||||
py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr);
|
||||
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
|
||||
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return res_node;
|
||||
|
|
|
@ -116,6 +116,7 @@ class PyObjectWrapper : public Named {
|
|||
// the object that needs to be resolved
|
||||
py::object obj_;
|
||||
};
|
||||
using PyObjectWrapperPtr = std::shared_ptr<PyObjectWrapper>;
|
||||
|
||||
// InterpretedObject class wrappers interpreted python object.
|
||||
class InterpretedObject final : public PyObjectWrapper {
|
||||
|
@ -137,9 +138,7 @@ class MsClassObject final : public PyObjectWrapper {
|
|||
: PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
|
||||
~MsClassObject() override = default;
|
||||
MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared<External>());
|
||||
}
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
};
|
||||
using MsClassObjectPtr = std::shared_ptr<MsClassObject>;
|
||||
|
||||
|
|
|
@ -1304,6 +1304,35 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
|
|||
return StaticGetterInferred(converted_value, data_conf, out_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||
const ValuePtr &data_value, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
MS_EXCEPTION_IF_NULL(data_value);
|
||||
// Get the name of item.
|
||||
if (!item_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
|
||||
}
|
||||
std::string item_name = item_value->cast<StringImmPtr>()->value();
|
||||
// Get ms_class object.
|
||||
if (!data_value->isa<parse::MsClassObject>()) {
|
||||
MS_LOG(EXCEPTION) << "Expect a ms_class object, but got " << data_value->ToString();
|
||||
}
|
||||
auto ms_class = data_value->cast<parse::MsClassObjectPtr>();
|
||||
MS_LOG(DEBUG) << "Resolve ms_class (" << ms_class->name() << ") with item " << item_name << ".";
|
||||
|
||||
// Get the attr/method of ms_class object.
|
||||
auto out_node = out_conf->node();
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
auto new_node = ResolveMsClassWithAttr(func_graph->manager(), ms_class, item_name, out_node);
|
||||
// Replace old node with the resolved new node in order list.
|
||||
func_graph->ReplaceInOrder(out_node, new_node);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
|
@ -1363,17 +1392,45 @@ int64_t GetResolveType(const TypePtr &data_type) {
|
|||
return kResolveTypeFunction;
|
||||
}
|
||||
|
||||
ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
|
||||
if (!abs->isa<abstract::PartialAbstractClosure>()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto partial_abs = abs->cast<abstract::PartialAbstractClosurePtr>();
|
||||
auto fn = partial_abs->fn();
|
||||
if (!fn->isa<abstract::PrimitiveAbstractClosure>()) {
|
||||
return nullptr;
|
||||
}
|
||||
// Check if type is kObjectTypeClass.
|
||||
auto args = partial_abs->args();
|
||||
if (args.size() > 0) {
|
||||
constexpr size_t first_input_index = 0;
|
||||
auto first_arg = args[first_input_index];
|
||||
MS_EXCEPTION_IF_NULL(first_arg);
|
||||
auto type = first_arg->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() == kObjectTypeClass) {
|
||||
return first_arg->BuildValue();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||
// Inputs: namespace and its static function; or class and its member function
|
||||
CheckArgsSize("StaticGetter", args_spec_list, 2);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
|
||||
MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
|
||||
TypePtr data_type = args_spec_list[0]->BuildType();
|
||||
ValuePtr item_value = args_spec_list[1]->BuildValue();
|
||||
constexpr size_t data_index = 0;
|
||||
constexpr size_t item_index = 1;
|
||||
auto data_args = args_spec_list[data_index];
|
||||
auto item_args = args_spec_list[item_index];
|
||||
MS_EXCEPTION_IF_NULL(data_args);
|
||||
MS_EXCEPTION_IF_NULL(item_args);
|
||||
MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString();
|
||||
TypePtr data_type = data_args->BuildType();
|
||||
ValuePtr item_value = item_args->BuildValue();
|
||||
|
||||
ScopePtr scope = kDefaultScope;
|
||||
if (out_conf != nullptr) {
|
||||
scope = out_conf->node()->scope();
|
||||
|
@ -1384,6 +1441,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
||||
auto class_value = GetMsClassObject(data_args);
|
||||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
|
||||
}
|
||||
int64_t resolve_type = GetResolveType(data_type);
|
||||
if (resolve_type == kResolveTypeUserDefineClass) {
|
||||
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
|
||||
|
@ -1581,46 +1642,47 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
// Check the type parameter.
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
||||
}
|
||||
|
||||
// Get the type parameter.
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
constexpr size_t type_index = 0;
|
||||
auto arg_class_type = args_spec_list[type_index];
|
||||
MS_EXCEPTION_IF_NULL(arg_class_type);
|
||||
TypePtr type = arg_class_type->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
|
||||
<< type->ToString();
|
||||
if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got "
|
||||
<< type->ToString();
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
ValuePtr value_track = arg_class_type->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
|
||||
std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
|
||||
parse::PyObjectWrapperPtr type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
|
||||
if (type_obj == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
|
||||
}
|
||||
|
||||
if (!type_obj->isa<parse::ClassType>()) {
|
||||
MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
|
||||
<< type_obj->ToString() << ".";
|
||||
if (!type_obj->isa<parse::ClassType>() && !type_obj->isa<parse::MsClassObject>()) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got "
|
||||
<< type_obj->ToString() << ".";
|
||||
}
|
||||
|
||||
auto class_type = type_obj->obj();
|
||||
MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
|
||||
MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
|
||||
|
||||
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
|
||||
py::tuple params = GetParameters(args_spec_list);
|
||||
|
||||
// Create class instance.
|
||||
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
|
||||
<< "` failed, only support to create \'Cell\' or \'Primitive\' object.";
|
||||
<< "` failed, only support to create \'Cell\', \'Primitive\' or "
|
||||
<< "user-defined Class decorated with \'ms_class\'.";
|
||||
}
|
||||
|
||||
// Process the object.
|
||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||
TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
|
||||
ValuePtr converted_ret = nullptr;
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
|
@ -1628,7 +1690,6 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
MS_LOG(EXCEPTION) << "Convert the python object failed";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(converted_ret);
|
||||
|
||||
if (converted_ret->isa<FuncGraph>()) {
|
||||
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
|
||||
}
|
||||
|
@ -1664,6 +1725,63 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
};
|
||||
|
||||
class CallInstanceEvaluator : public TransitionPrimEvaluator {
|
||||
public:
|
||||
CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {}
|
||||
~CallInstanceEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_spec_list should not be empty.";
|
||||
}
|
||||
constexpr size_t cls_index = 0;
|
||||
auto arg_cls = args_spec_list[cls_index];
|
||||
MS_EXCEPTION_IF_NULL(arg_cls);
|
||||
TypePtr type = arg_cls->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kObjectTypeClass) {
|
||||
MS_LOG(EXCEPTION) << "CallInstanceEvaluator require first parameter should be an object of TypeClass, but got "
|
||||
<< type->ToString();
|
||||
}
|
||||
ValuePtr value_track = arg_cls->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
parse::MsClassObjectPtr ms_class = dyn_cast<parse::MsClassObject>(value_track);
|
||||
if (ms_class == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "CallInstanceEvaluator only supports MsClassObject.";
|
||||
}
|
||||
|
||||
// Call class instance, net(x, y) -> net.__call__(x, y)
|
||||
py::object cls_obj = ms_class->obj();
|
||||
const std::string call_func = "__call__";
|
||||
if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
|
||||
MS_LOG(EXCEPTION) << ms_class->name() << " has no " << call_func << " function, please check the code.";
|
||||
}
|
||||
py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
|
||||
FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
|
||||
if (call_func_graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parse python object " << call_func << " failed.";
|
||||
}
|
||||
FuncGraphManagerPtr manager = engine->func_graph_manager();
|
||||
manager->AddFuncGraph(call_func_graph);
|
||||
|
||||
// Replace net with net.__call__
|
||||
AnfNodePtr old_node = out_conf->node();
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
CNodePtr old_cnode = dyn_cast<CNode>(old_node);
|
||||
MS_EXCEPTION_IF_NULL(old_cnode);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(call_func_graph)};
|
||||
for (size_t i = 1; i < old_cnode->size(); i++) {
|
||||
(void)inputs.emplace_back(old_cnode->input(i));
|
||||
}
|
||||
FuncGraphPtr func_graph = out_conf->func_graph();
|
||||
auto new_cnode = func_graph->NewCNode(inputs);
|
||||
// Continue to eval new_cnode.
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
|
||||
return engine->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
};
|
||||
|
||||
class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
||||
public:
|
||||
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
|
||||
|
@ -2085,6 +2203,7 @@ void InitPrimEvaluatorConstructors() {
|
|||
constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
|
||||
constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
|
||||
constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
|
||||
constructor[prim::kPrimCallInstance] = std::make_shared<CallInstanceEvaluator>();
|
||||
constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
|
||||
constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
|
||||
constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();
|
||||
|
|
|
@ -82,7 +82,7 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
|
|||
if (abstract->isa<AbstractScalar>()) {
|
||||
TypePtr type = abstract->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->isa<EnvType>()) {
|
||||
if (type->isa<EnvType>() || type->isa<MsClassType>()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << abstract->ToString() << ", node: " << node->DebugString();
|
||||
}
|
||||
if (type->isa<Problem>() || type->isa<External>()) {
|
||||
|
|
|
@ -345,6 +345,24 @@ class MS_CORE_API Problem final : public Type {
|
|||
};
|
||||
using ProblemPtr = std::shared_ptr<Problem>;
|
||||
|
||||
/// \brief MsClassType defines a type which is ms_class.
|
||||
class MS_CORE_API MsClassType final : public Type {
|
||||
public:
|
||||
/// \brief The constructor of External.
|
||||
///
|
||||
/// \return The instance of External.
|
||||
MsClassType() : Type(kObjectTypeClass) {}
|
||||
|
||||
/// \brief The destructor of External.
|
||||
~MsClassType() override = default;
|
||||
MS_DECLARE_PARENT(MsClassType, Type)
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeClass; }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<MsClassType>(); }
|
||||
std::string DumpText() const override { return "MsClassType"; }
|
||||
};
|
||||
using MsClassTypePtr = std::shared_ptr<MsClassType>;
|
||||
|
||||
/// \brief External defines a type which is external.
|
||||
class MS_CORE_API External final : public Type {
|
||||
public:
|
||||
|
@ -360,9 +378,6 @@ class MS_CORE_API External final : public Type {
|
|||
TypeId generic_type_id() const override { return kMetaTypeExternal; }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<External>(); }
|
||||
std::string DumpText() const override { return "ExternalType"; }
|
||||
|
||||
private:
|
||||
TypePtr kind;
|
||||
};
|
||||
using ExternalPtr = std::shared_ptr<External>;
|
||||
|
||||
|
|
|
@ -939,6 +939,7 @@ GVAR_DEF(PrimitivePtr, kPrimResolve, std::make_shared<Primitive>("resolve"));
|
|||
GVAR_DEF(PrimitivePtr, kPrimEmbed, std::make_shared<Primitive>("embed"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimRefToEmbed, std::make_shared<Primitive>("RefToEmbed"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCreateInstance, std::make_shared<Primitive>("create_instance"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCallInstance, std::make_shared<Primitive>("call_instance"));
|
||||
|
||||
// Other miscellaneous
|
||||
GVAR_DEF(PrimitivePtr, kPrimGetRefOrigin, std::make_shared<Primitive>("get_ref_origin"));
|
||||
|
|
|
@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
|
||||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
|
||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
|
||||
#else
|
||||
static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
|
||||
"array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
|
||||
|
@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
|
||||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1",
|
||||
"resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print"};
|
||||
"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
|
||||
#endif
|
||||
static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
|
||||
prim::kPrimMicroStepAllGather};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,20 +18,18 @@ Interfaces for parser module in c++.
|
|||
|
||||
from .parser import (Parser, create_instance, is_supported_create_instance_type, generate_scope,
|
||||
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
|
||||
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
|
||||
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
|
||||
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
|
||||
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
|
||||
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
|
||||
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
get_ms_class_attr)
|
||||
create_slice_obj, get_obj_id, get_module_namespace, get_obj_type, get_object_key,
|
||||
get_ast_type, get_node_type, get_args, get_args_default_values, get_ast_namespace_symbol,
|
||||
get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name,
|
||||
eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, get_dataclass_attributes, get_dataclass_methods)
|
||||
|
||||
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
|
||||
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
|
||||
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',
|
||||
'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type',
|
||||
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
|
||||
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
|
||||
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
|
||||
'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'get_ms_class_attr']
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
'create_slice_obj', 'get_obj_id', 'get_module_namespace', 'get_obj_type', 'get_object_key',
|
||||
'get_ast_type', 'get_node_type', 'get_args', 'get_args_default_values', 'get_ast_namespace_symbol',
|
||||
'get_operation_symbol', 'get_operation_namespace_symbol', 'get_parse_method_of_class', 'get_scope_name',
|
||||
'eval_script', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'get_dataclass_attributes', 'get_dataclass_methods']
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -118,7 +118,11 @@ class ClassMemberNamespace(Namespace):
|
|||
except ValueError:
|
||||
raise UnboundLocalError(name)
|
||||
except KeyError:
|
||||
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
||||
# Check if cls is user-defined class decorated with ms_class. If true, an exception will be thrown.
|
||||
cls = d.__class__
|
||||
if hasattr(cls, '__ms_class__'):
|
||||
raise NotImplementedError(f"'{cls.__name__ }' object has no attribute or method: '{name}'.")
|
||||
logger.info(f"'{cls.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
||||
raise AttributeError(name)
|
||||
|
||||
|
||||
|
@ -142,5 +146,4 @@ class ClassAttrNamespace(Namespace):
|
|||
except ValueError:
|
||||
raise UnboundLocalError(name)
|
||||
except KeyError:
|
||||
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
|
||||
raise AttributeError(name)
|
||||
raise AttributeError(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}'.")
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -23,7 +23,6 @@ import hashlib
|
|||
import inspect
|
||||
import types
|
||||
import importlib
|
||||
from dataclasses import is_dataclass
|
||||
from textwrap import dedent
|
||||
|
||||
import asttokens
|
||||
|
@ -324,24 +323,26 @@ def get_class_instance_type(obj):
|
|||
"""Get the class instance detail type."""
|
||||
# check the obj type
|
||||
logger.debug("Get the class type(%r)", obj)
|
||||
class_type = CLASS_INSTANCE_TYPE_INVALID
|
||||
if _is_class_instance(obj):
|
||||
if isinstance(obj, nn.Cell):
|
||||
class_type = CLASS_INSTANCE_TYPE_CELL
|
||||
elif isinstance(obj, ops.Primitive):
|
||||
class_type = CLASS_INSTANCE_TYPE_PRIMITIVE
|
||||
# Add the other type base requirement
|
||||
return class_type
|
||||
if isinstance(obj, nn.Cell):
|
||||
return CLASS_INSTANCE_TYPE_CELL
|
||||
if isinstance(obj, ops.Primitive):
|
||||
return CLASS_INSTANCE_TYPE_PRIMITIVE
|
||||
return CLASS_INSTANCE_TYPE_INVALID
|
||||
|
||||
|
||||
def _is_class_instance(obj):
|
||||
"""Confirm the obj is class instance."""
|
||||
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj)
|
||||
def _is_ms_class(obj):
|
||||
"""Check if obj is ms_class object."""
|
||||
return hasattr(obj, '__ms_class__')
|
||||
|
||||
|
||||
def _is_dataclass_instance(obj):
|
||||
"""Check whether a class is an instance of a dataclass (and not a dataclass itself)"""
|
||||
return is_dataclass(obj) and not isinstance(obj, type)
|
||||
return hasattr(obj, "__dataclass_fields__") and not isinstance(obj, type)
|
||||
|
||||
|
||||
def _is_class_instance(obj):
|
||||
"""Confirm the obj is class instance."""
|
||||
return isinstance(obj, (nn.Cell, ops.Primitive)) or _is_dataclass_instance(obj) or _is_ms_class(obj)
|
||||
|
||||
|
||||
def _convert_tuple_to_args_kwargs(params):
|
||||
|
@ -358,7 +359,7 @@ def _convert_tuple_to_args_kwargs(params):
|
|||
|
||||
def is_supported_create_instance_type(cls_type):
|
||||
"""Check if cls_type is a supported instance type."""
|
||||
return issubclass(cls_type, (nn.Cell, ops.Primitive))
|
||||
return issubclass(cls_type, (nn.Cell, ops.Primitive)) or _is_ms_class(cls_type)
|
||||
|
||||
|
||||
def create_instance(cls_type, params=None):
|
||||
|
@ -440,28 +441,19 @@ def get_dataclass_methods(cls):
|
|||
return methods
|
||||
|
||||
|
||||
def is_class_type(cls):
|
||||
"""Check if cls is a class type."""
|
||||
return isinstance(cls, type)
|
||||
|
||||
|
||||
def get_ms_class_name(cls):
|
||||
"""Get the name of the class instance decorated by ms_class."""
|
||||
# Check if cls is nn.Cell.
|
||||
if isinstance(cls, nn.Cell):
|
||||
raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.")
|
||||
if isinstance(cls, type):
|
||||
name = cls.__name__
|
||||
else:
|
||||
name = cls.__class__.__name__
|
||||
# Get the name of cls.
|
||||
cls_name = cls.__module__ + '.' + name
|
||||
return cls_name
|
||||
|
||||
|
||||
def get_ms_class_attr(cls, name: str):
|
||||
"""Get attribute or method of ms_class obj."""
|
||||
# Don't take into account python magic methods and private variables.
|
||||
if name.startswith('_'):
|
||||
raise AttributeError(f"{name} is a private variable or magic method, which is not supported.")
|
||||
if not hasattr(cls, name):
|
||||
raise AttributeError(f"{cls} has no attribute: {name}.")
|
||||
return getattr(cls, name)
|
||||
return cls.__name__
|
||||
return cls.__class__.__name__
|
||||
|
||||
|
||||
def convert_to_ms_tensor(data):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
#
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -15,10 +15,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""standard_method"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
|
||||
from mindspore import Tensor, Parameter, CSRTensor, COOTensor, ms_class
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
|
@ -1828,16 +1825,16 @@ def float_floordiv(x, y):
|
|||
#############
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ms_class
|
||||
class SequenceIterator:
|
||||
"""
|
||||
SequenceIterator is a util dataclass for iterating sequence object.
|
||||
SequenceIterator is a util class for iterating sequence object.
|
||||
|
||||
Iterator to use for sequences like List, Array.
|
||||
"""
|
||||
|
||||
idx: int
|
||||
seq: list
|
||||
def __init__(self, idx, seq):
|
||||
self.idx = idx
|
||||
self.seq = seq
|
||||
|
||||
@core(ignore_values=True)
|
||||
def __ms_hasnext__(self):
|
||||
|
|
|
@ -0,0 +1,451 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context, ms_class
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the attributes of user-defined classes decorated by ms_class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = Tensor(1, dtype=mstype.int32)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.number
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out.asnumpy() == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the methods of user-defined classes decorated by ms_class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.val = Tensor(2, dtype=mstype.int32)
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.inner_net.act(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor(2, dtype=mstype.int32)
|
||||
y = Tensor(3, dtype=mstype.int32)
|
||||
net = Net()
|
||||
out = net(x, y)
|
||||
assert out.asnumpy() == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_call():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Call the __call__ function of user-defined classes decorated by ms_class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def __call__(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, val):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet(val)
|
||||
|
||||
def construct(self, x, y):
|
||||
out = self.inner_net(x, y)
|
||||
return out
|
||||
|
||||
val = Tensor(2, dtype=mstype.int32)
|
||||
x = Tensor(3, dtype=mstype.int32)
|
||||
y = Tensor(4, dtype=mstype.int32)
|
||||
net = Net(val)
|
||||
out = net(x, y)
|
||||
assert out.asnumpy() == 14
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_input_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the attributes of user-defined classes decorated by ms_class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = Tensor(np.array([1, 2, 3]))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = net()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.number
|
||||
return out
|
||||
|
||||
net = Net(InnerNet)
|
||||
out = net()
|
||||
expect_res = np.array([1, 2, 3])
|
||||
assert np.all(out.asnumpy() == expect_res)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_input_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the methods of user-defined classes decorated by ms_class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.val = Tensor(2, dtype=mstype.int32)
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = net()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.act(1, 2)
|
||||
return out
|
||||
|
||||
net = Net(InnerNet)
|
||||
out = net()
|
||||
assert out.asnumpy() == 6
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_class_nested():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test nested ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Inner:
|
||||
def __init__(self):
|
||||
self.number = Tensor(1, dtype=mstype.int32)
|
||||
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.inner = Inner()
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.inner.number
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out.asnumpy() == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_cell_nested():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test nested ms_class and cell in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, val):
|
||||
super().__init__()
|
||||
self.val = val
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.val
|
||||
|
||||
@ms_class
|
||||
class TrainNet():
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, x):
|
||||
out = self.net(x)
|
||||
return out * 2
|
||||
|
||||
def __init__(self, net):
|
||||
self.net = net
|
||||
loss_net = self.Loss(self.net)
|
||||
self.number = loss_net(10)
|
||||
|
||||
global_net = Net(1)
|
||||
class LearnNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.value = TrainNet(global_net).number
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.value
|
||||
|
||||
leanrn_net = LearnNet()
|
||||
out = leanrn_net(3)
|
||||
print(out)
|
||||
assert out == 25
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_type_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the attributes of class type.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
val = Tensor(2, dtype=mstype.int32)
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet
|
||||
|
||||
# Support accessing attributes of class type, but do not support
|
||||
# accessing methods, e.g. self.inner_net.act(1, 2)
|
||||
def construct(self):
|
||||
out = self.inner_net.val
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_create_instance_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the attributes of the created class instance.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self, val):
|
||||
self.number = val + 3
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet
|
||||
|
||||
def construct(self, x):
|
||||
net = self.inner_net(x)
|
||||
return net.number
|
||||
|
||||
net = Net()
|
||||
out = net(2)
|
||||
assert out == 5
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_create_instance_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Access the methods of the created class instance.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self, val):
|
||||
self.number = val
|
||||
|
||||
def act(self, x, y):
|
||||
return self.number * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet
|
||||
|
||||
def construct(self, x, y, z):
|
||||
net = self.inner_net(x)
|
||||
return net.act(y, z)
|
||||
|
||||
x = 2
|
||||
y = Tensor(2, dtype=mstype.int32)
|
||||
z = Tensor(3, dtype=mstype.int32)
|
||||
net = Net()
|
||||
out = net(x, y, z)
|
||||
assert out.asnumpy() == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_class_create_instance_call():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Call the __call__ function of the created class instance.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self, number):
|
||||
self.number = number
|
||||
|
||||
def __call__(self, x, y):
|
||||
return self.number * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet
|
||||
|
||||
def construct(self, x, y, z):
|
||||
net = self.inner_net(x)
|
||||
out = net(y, z)
|
||||
return out
|
||||
|
||||
x = 2
|
||||
y = Tensor(2, dtype=mstype.int32)
|
||||
z = Tensor(3, dtype=mstype.int32)
|
||||
net = Net()
|
||||
out = net(x, y, z)
|
||||
assert out == 10
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_raise_error_not_class_type():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Decorator ms_class cannot be used for non-class types.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
with pytest.raises(TypeError):
|
||||
@ms_class
|
||||
def func(x, y):
|
||||
return x + y
|
||||
|
||||
func(1, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_raise_error_decorate_cell():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Decorator ms_class cannot be used for nn.Cell
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
x = Tensor(1)
|
||||
net = Net()
|
||||
net(x)
|
|
@ -1,416 +0,0 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context, ms_class, ms_function
|
||||
from . import test_graph_fallback
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_self_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.dim = 1
|
||||
|
||||
def construct(self, x):
|
||||
batch = x.shape[0]
|
||||
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
|
||||
return one * x
|
||||
|
||||
net = Network()
|
||||
x = Tensor([1, 2], mstype.float32)
|
||||
out = net(x)
|
||||
expect = np.array([[1., 2.], [1., 2.]])
|
||||
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
|
||||
|
||||
|
||||
def test_fallback_self_attr_fn():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self, fn):
|
||||
super(Network, self).__init__()
|
||||
self.fn = fn
|
||||
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
net = Network(fn)
|
||||
out = net()
|
||||
expect = np.array([4, 6, 8])
|
||||
assert np.all(out.asnumpy() == expect)
|
||||
|
||||
|
||||
def test_fallback_self_attr_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.attr in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.value = [2, 2, 3]
|
||||
|
||||
def construct(self):
|
||||
x = np.array(self.value.count(2))
|
||||
return Tensor(x)
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_fallback_self_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
expect = np.array([4, 6, 8])
|
||||
assert np.all(out.asnumpy() == expect)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
|
||||
def test_fallback_self_method_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test self.method in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
z = self.fn(x, y)
|
||||
out = Tensor(z)
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_fallback_import_modules():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: add_func is defined in test_graph_fallback.py
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def use_imported_module(x, y):
|
||||
out = test_graph_fallback.add_func(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor(2, dtype=mstype.int32)
|
||||
y = Tensor(3, dtype=mstype.int32)
|
||||
out = use_imported_module(x, y)
|
||||
print(out)
|
||||
|
||||
|
||||
def test_fallback_class_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class attributes in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.number
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_fallback_class_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class methods in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.val = 2
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.act(1, 2)
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_class_input_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class attributes in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = Tensor(np.array([1, 2, 3]))
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = net()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.number
|
||||
return out
|
||||
|
||||
net = Net(InnerNet)
|
||||
out = net()
|
||||
expect_res = np.array([1, 2, 3])
|
||||
assert np.all(out.asnumpy() == expect_res)
|
||||
|
||||
|
||||
def test_fallback_class_input_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test user-defined class methods in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.val = 2
|
||||
|
||||
def act(self, x, y):
|
||||
return self.val * (x + y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = net()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.act(1, 2)
|
||||
return out
|
||||
|
||||
net = Net(InnerNet)
|
||||
out = net()
|
||||
assert out == 6
|
||||
|
||||
|
||||
def test_fallback_class_class_nested():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test nested ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Inner:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.inner = Inner()
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self):
|
||||
out = self.inner_net.inner.number
|
||||
return out
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_fallback_class_cell_nested():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test nested ms_class and cell in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, val):
|
||||
super().__init__()
|
||||
self.val = val
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.val
|
||||
|
||||
@ms_class
|
||||
class TrainNet():
|
||||
class Loss(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, x):
|
||||
out = self.net(x)
|
||||
return out * 2
|
||||
|
||||
def __init__(self, net):
|
||||
self.net = net
|
||||
loss_net = self.Loss(self.net)
|
||||
self.number = loss_net(10)
|
||||
|
||||
global_net = Net(1)
|
||||
class LearnNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.value = TrainNet(global_net).number
|
||||
|
||||
def construct(self, x):
|
||||
return x + self.value
|
||||
|
||||
leanrn_net = LearnNet()
|
||||
out = leanrn_net(3)
|
||||
print(out)
|
||||
assert out == 25
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support in graph yet')
|
||||
def test_fallback_class_isinstance():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.inner_net = InnerNet()
|
||||
|
||||
def construct(self, x):
|
||||
if isinstance(self.inner_net, InnerNet):
|
||||
return x + 10
|
||||
return x
|
||||
|
||||
net = Net()
|
||||
out = net(5)
|
||||
assert out == 15
|
||||
|
||||
|
||||
def test_fallback_raise_error_not_class_type():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
with pytest.raises(TypeError):
|
||||
@ms_class
|
||||
def func(x, y):
|
||||
return x + y
|
||||
|
||||
func(1, 2)
|
||||
|
||||
|
||||
def test_fallback_raise_error_not_class_instance():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class InnerNet:
|
||||
def __init__(self):
|
||||
self.number = 1
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
out = InnerNet().number
|
||||
return out
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
net = Net()
|
||||
net()
|
||||
|
||||
|
||||
def test_fallback_raise_error_decorate_cell():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test ms_class in graph.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_class
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
x = Tensor(1)
|
||||
net = Net()
|
||||
net(x)
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback """
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, context, ms_function
|
||||
from . import test_graph_fallback
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_fallback_self_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Use self.attr in expressions supported by JIT Fallback.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.dim = 1
|
||||
|
||||
def construct(self, x):
|
||||
batch = x.shape[0]
|
||||
one = Tensor(np.ones([batch, self.dim]), mstype.float32)
|
||||
return one * x
|
||||
|
||||
net = Network()
|
||||
x = Tensor([1, 2], mstype.float32)
|
||||
out = net(x)
|
||||
expect = np.array([[1., 2.], [1., 2.]])
|
||||
assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
|
||||
|
||||
|
||||
def test_fallback_self_attr_fn():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Use self.attr of type function in expressions supported by JIT Fallback.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self, fn):
|
||||
super(Network, self).__init__()
|
||||
self.fn = fn
|
||||
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
net = Network(fn)
|
||||
out = net()
|
||||
expect = np.array([4, 6, 8])
|
||||
assert np.all(out.asnumpy() == expect)
|
||||
|
||||
|
||||
def test_fallback_self_attr_attr():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: In expressions supported by JIT Fallback, use the attribute of self.attr.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Network, self).__init__()
|
||||
self.value = [2, 2, 3]
|
||||
|
||||
def construct(self):
|
||||
x = np.array(self.value.count(2))
|
||||
return Tensor(x)
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_fallback_self_method():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Use self.method in expressions supported by JIT Fallback.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class Network(nn.Cell):
|
||||
def construct(self):
|
||||
x = np.array([1, 2, 3])
|
||||
y = np.array([3, 4, 5])
|
||||
out = Tensor(self.fn(x, y))
|
||||
return out
|
||||
|
||||
def fn(self, x, y):
|
||||
return x + y
|
||||
|
||||
net = Network()
|
||||
out = net()
|
||||
expect = np.array([4, 6, 8])
|
||||
assert np.all(out.asnumpy() == expect)
|
||||
|
||||
|
||||
def test_fallback_import_modules():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Check whether the call to the third-party library is correct. It has nothing to do with class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def use_imported_module(x, y):
|
||||
out = test_graph_fallback.add_func(x, y)
|
||||
return out
|
||||
|
||||
x = Tensor(2, dtype=mstype.int32)
|
||||
y = Tensor(3, dtype=mstype.int32)
|
||||
out = use_imported_module(x, y)
|
||||
print(out)
|
Loading…
Reference in New Issue