support graph hasattr and getattr
This commit is contained in:
parent
a11e01a0d2
commit
88bc967cdd
|
@ -46,7 +46,11 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
|
|||
if (IsValueNode<parse::MsClassObject>(object_node)) {
|
||||
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node)->obj();
|
||||
auto attr_str = GetValue<std::string>(GetValueNode(attr_node));
|
||||
return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
|
||||
auto new_node = parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
|
||||
if (new_node == nullptr || IsValueNode<None>(new_node)) {
|
||||
MS_EXCEPTION(AttributeError) << py::str(ms_class) << " object has no attribute: " << attr_str << ".";
|
||||
}
|
||||
return new_node;
|
||||
}
|
||||
// {prim::kPrimGetAttr, bool, attr}
|
||||
if (IsValueNode<BoolImm>(object_node)) {
|
||||
|
|
|
@ -495,6 +495,9 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py:
|
|||
MS_EXCEPTION_IF_NULL(attr_str_ptr);
|
||||
const auto &attr_str = attr_str_ptr->value();
|
||||
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, get_attr_node);
|
||||
if (res == nullptr || IsValueNode<None>(res)) {
|
||||
MS_EXCEPTION(AttributeError) << py::str(sequence[i]) << " object has no attribute: " << attr_str << ".";
|
||||
}
|
||||
(void)inputs.emplace_back(res);
|
||||
}
|
||||
} else {
|
||||
|
@ -536,9 +539,12 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::
|
|||
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
|
||||
}
|
||||
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
|
||||
MS_LOG(EXCEPTION) << py::str(cls_obj) << " has not attribute: " << attr << ".";
|
||||
return nullptr;
|
||||
}
|
||||
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
|
||||
if (py::isinstance<py::none>(attr_obj)) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, get_attr_node);
|
||||
TraceManager::ClearParseOrResolveDebugInfo();
|
||||
return res_node;
|
||||
|
|
|
@ -1368,11 +1368,13 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
|||
return eng->ForwardConfig(old_conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, const ValuePtr &attr_value,
|
||||
const ValuePtr &data_value, const AnfNodeConfigPtr &out_conf) {
|
||||
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_spec_list, const ValuePtr &data_value,
|
||||
const AnfNodeConfigPtr &out_conf, const std::string &data) {
|
||||
constexpr size_t item_index = 1;
|
||||
auto item_args = args_spec_list[item_index];
|
||||
ValuePtr item_value = item_args->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(data_value);
|
||||
MS_EXCEPTION_IF_NULL(attr_value);
|
||||
ValuePtr item_value = attr_value;
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
if (item_value->isa<StringImm>()) {
|
||||
item_value = std::make_shared<parse::Symbol>(item_value->cast_ptr<StringImm>()->value());
|
||||
}
|
||||
|
@ -1391,6 +1393,23 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
|
|||
if (new_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Resolve node failed";
|
||||
}
|
||||
if (IsValueNode<None>(new_node)) {
|
||||
// Do not find the attribute.
|
||||
constexpr auto kMaxArgsLen = 3;
|
||||
bool has_default = (args_spec_list.size() == kMaxArgsLen);
|
||||
if (!has_default) {
|
||||
MS_EXCEPTION(AttributeError) << data << " object has no attribute " << symbol->symbol();
|
||||
}
|
||||
auto out_cnode = out_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
constexpr auto kDefaultIndex = 3;
|
||||
auto default_node = out_cnode->inputs()[kDefaultIndex];
|
||||
func_graph->ReplaceInOrder(out_node, default_node);
|
||||
auto eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
if (pipeline::GetJitLevel() == "O0" && IsValueNode<FuncGraph>(new_node)) {
|
||||
UpdateDebugInfo(GetValueNode<FuncGraphPtr>(new_node), out_node->scope(), out_node->debug_info());
|
||||
}
|
||||
|
@ -1404,7 +1423,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
|
|||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForNameSpace(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
|
||||
EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_spec_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
// args_spec_list: same as StaticGetter
|
||||
constexpr size_t args_min_size = 2;
|
||||
|
@ -1413,23 +1432,33 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AnalysisEnginePtr &, const Abs
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
// An external type.
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
auto data_value = args_spec_list[0]->BuildValue();
|
||||
constexpr auto kDataIndex = 0;
|
||||
constexpr auto kItemIndex = 1;
|
||||
auto data = args_spec_list[kDataIndex];
|
||||
auto item = args_spec_list[kItemIndex];
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
auto data_value = data->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(data_value);
|
||||
if (!data_value->isa<parse::NameSpace>()) {
|
||||
MS_EXCEPTION(TypeError) << "Not supported to get attribute for " << data_value->ToString()
|
||||
<< "\nThe first argument should be a NameSpace, but got " << args_spec_list[0]->ToString();
|
||||
<< "\nThe first argument should be a NameSpace, but got " << data->ToString();
|
||||
}
|
||||
|
||||
auto item_value = args_spec_list[1]->BuildValue();
|
||||
auto item_value = item->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
return GetEvaluatedValueForNameSpaceString(nullptr, item_value, data_value, out_conf);
|
||||
auto data_type = data->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
const auto &data_id_str = TypeIdToString(data_type->type_id());
|
||||
return GetEvaluatedValueForNameSpaceString(args_spec_list, data_value, out_conf, data_id_str);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &, const ValuePtr &item_value,
|
||||
const ValuePtr &data_value, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList &args_spec_list,
|
||||
const ValuePtr &data_value, const AnfNodeConfigPtr &out_conf) {
|
||||
constexpr size_t item_index = 1;
|
||||
auto item_args = args_spec_list[item_index];
|
||||
ValuePtr item_value = item_args->BuildValue();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
MS_EXCEPTION_IF_NULL(data_value);
|
||||
// Get the name of item.
|
||||
|
@ -1447,8 +1476,20 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &,
|
|||
// Get the attr/method of ms_class object.
|
||||
auto out_node = out_conf->node();
|
||||
FuncGraphPtr func_graph = out_node->func_graph();
|
||||
// If the attribute is not found and the default is not set, AttributeError will be raised.
|
||||
auto new_node = parse::ResolveMsClassWithAttr(func_graph->manager(), ms_class->obj(), item_name, out_node);
|
||||
// Replace old node with the resolved new node in order list.
|
||||
if (new_node == nullptr || IsValueNode<None>(new_node)) {
|
||||
constexpr auto kMaxArgsLen = 3;
|
||||
bool has_default = (args_spec_list.size() == kMaxArgsLen);
|
||||
if (!has_default) {
|
||||
MS_EXCEPTION(AttributeError) << py::str(ms_class->obj()) << " object has no attribute: " << item_name << ".";
|
||||
}
|
||||
constexpr auto kDefaultIndex = 3;
|
||||
auto out_cnode = out_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
new_node = out_cnode->inputs()[kDefaultIndex];
|
||||
}
|
||||
|
||||
func_graph->ReplaceInOrder(out_node, new_node);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
|
@ -1456,9 +1497,12 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &,
|
|||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||
const FuncGraphPtr &func_value, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &args_spec_list,
|
||||
const FuncGraphPtr &func_value, const AnfNodeConfigPtr &out_conf) {
|
||||
constexpr size_t item_index = 1;
|
||||
auto item_args = args_spec_list[item_index];
|
||||
ValuePtr item_value = item_args->BuildValue();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
MS_EXCEPTION_IF_NULL(func_value);
|
||||
if (!item_value->isa<StringImm>()) {
|
||||
|
@ -1471,22 +1515,32 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engi
|
|||
auto wrapper_obj = dyn_cast_ptr<parse::PyObjectWrapper>(python_obj);
|
||||
MS_EXCEPTION_IF_NULL(wrapper_obj);
|
||||
py::object real_python_obj = wrapper_obj->obj();
|
||||
const auto &py_obj_str = py::str(real_python_obj);
|
||||
MS_LOG(DEBUG) << "item_value: " << item_value->ToString() << ", func_value: " << func_value->ToString()
|
||||
<< ", real_python_obj: " << py::str(real_python_obj);
|
||||
<< ", real_python_obj: " << py_obj_str;
|
||||
if (py::isinstance<Cell>(real_python_obj)) {
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
py::object ns_obj =
|
||||
python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, real_python_obj);
|
||||
auto ns = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
|
||||
return GetEvaluatedValueForNameSpaceString(nullptr, item_value, ns, out_conf);
|
||||
return GetEvaluatedValueForNameSpaceString(args_spec_list, ns, out_conf, py_obj_str);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
|
||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
constexpr size_t data_index = 0;
|
||||
constexpr size_t item_index = 1;
|
||||
auto data_args = args_spec_list[data_index];
|
||||
auto item_args = args_spec_list[item_index];
|
||||
MS_EXCEPTION_IF_NULL(data_args);
|
||||
MS_EXCEPTION_IF_NULL(item_args);
|
||||
ValuePtr item_value = item_args->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
TypePtr data_type = data_args->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
// The method maybe a Primitive or Composite
|
||||
if (!item_value->isa<StringImm>()) {
|
||||
|
@ -1498,8 +1552,22 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
|
|||
if (require.empty()) {
|
||||
require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
|
||||
if (require.empty()) {
|
||||
MS_LOG(EXCEPTION) << "MindSpore not support to get attribute \'" << item_name << "\' of a type["
|
||||
<< data_type->ToString() << "]";
|
||||
constexpr auto kMaxArgsLen = 3;
|
||||
bool has_default = (args_spec_list.size() == kMaxArgsLen);
|
||||
if (!has_default) {
|
||||
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
|
||||
}
|
||||
auto out_node = out_conf->node();
|
||||
auto out_cnode = out_node->cast_ptr<CNode>();
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
auto fg = out_cnode->func_graph();
|
||||
constexpr auto kDefaultIndex = 3;
|
||||
auto default_node = out_cnode->inputs()[kDefaultIndex];
|
||||
fg->ReplaceInOrder(out_node, default_node);
|
||||
auto eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
require_type = REQUIRE_TYPE::ATTR;
|
||||
}
|
||||
|
@ -1550,7 +1618,6 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
|
|||
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||
// Inputs: namespace and its static function; or class and its member function
|
||||
CheckArgsSize("StaticGetter", args_spec_list, 2);
|
||||
|
||||
constexpr size_t data_index = 0;
|
||||
constexpr size_t item_index = 1;
|
||||
|
@ -1571,14 +1638,32 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
||||
if (data_args->isa<abstract::AbstractScalar>()) {
|
||||
ValuePtr data_value = data_args->BuildValue();
|
||||
if (data_value->isa<parse::InterpretedObject>()) {
|
||||
MS_EXCEPTION(TypeError) << "Do not support to get attribute from interpret object.";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto KMaxArgSize = 3;
|
||||
if (args_spec_list.size() == KMaxArgSize) {
|
||||
constexpr size_t default_index = 2;
|
||||
auto default_args = args_spec_list[default_index];
|
||||
if (default_args->isa<abstract::AbstractScalar>()) {
|
||||
ValuePtr default_value = default_args->BuildValue();
|
||||
if (default_value->isa<parse::InterpretedObject>()) {
|
||||
MS_EXCEPTION(TypeError) << "For 'getattr', the third input 'default' can not be interpreted object.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto class_value = GetMsClassObject(data_args);
|
||||
if (class_value != nullptr) {
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
|
||||
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
|
||||
}
|
||||
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
|
||||
if (data_func_graph != nullptr) {
|
||||
auto res =
|
||||
GetEvaluatedValueForCellAttrOrMethod(engine, item_value, data_func_graph->func_graph(), data_conf, out_conf);
|
||||
auto res = GetEvaluatedValueForCellAttrOrMethod(args_spec_list, data_func_graph->func_graph(), out_conf);
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
|
@ -1586,9 +1671,9 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
// Try to search method map, if not found, the data_type should be External type.
|
||||
TypePtr data_type = data_args->BuildType();
|
||||
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
|
||||
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
|
||||
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_spec_list, data_conf, out_conf);
|
||||
}
|
||||
return GetEvaluatedValueForNameSpace(engine, args_spec_list, out_conf);
|
||||
return GetEvaluatedValueForNameSpace(args_spec_list, out_conf);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -1748,15 +1833,25 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
|||
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||
constexpr auto kGetAttrArgSize = 2;
|
||||
constexpr auto kGetAttrArgMinSize = 2;
|
||||
constexpr auto kGetAttrArgMaxSize = 3;
|
||||
constexpr auto kAttrIndex = 1;
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
// Inputs: data, item
|
||||
if (args_spec_list.size() != kGetAttrArgSize) {
|
||||
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
||||
const auto args_size = args_spec_list.size();
|
||||
if (args_size != kGetAttrArgMinSize && args_size != kGetAttrArgMaxSize) {
|
||||
MS_LOG(EXCEPTION) << "For Primitive GetAttr, the input size should be 2 or 3, but got size:" << args_size;
|
||||
}
|
||||
auto attr_abs = args_spec_list[kAttrIndex];
|
||||
auto attr_abs_type = attr_abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(attr_abs_type);
|
||||
auto type_id = attr_abs_type->type_id();
|
||||
if (type_id != TypeId::kObjectTypeString) {
|
||||
MS_EXCEPTION(TypeError) << "getattr(): attribute name must be string but got: " << TypeIdToString(type_id);
|
||||
}
|
||||
EvalResultPtr ret = nullptr;
|
||||
if (bound_node() != nullptr) {
|
||||
|
|
|
@ -107,12 +107,12 @@ _unsupported_internal_type = (
|
|||
)
|
||||
|
||||
_hybrid_type = (
|
||||
print, len, enumerate, zip, map, filter, abs, all, any, round, max, min,
|
||||
print, len, enumerate, zip, map, filter, abs, all, any, round, max, min, hasattr
|
||||
)
|
||||
|
||||
# Unsupported python builtin type in JIT Fallback.
|
||||
_fallback_unsupported_python_builtin_type = (
|
||||
compile, eval, exec, input, open, delattr, setattr, getattr, hasattr, super, staticmethod, classmethod, __import__,
|
||||
compile, eval, exec, input, open, delattr, setattr, super, staticmethod, classmethod, __import__,
|
||||
memoryview, property,
|
||||
)
|
||||
|
||||
|
|
|
@ -143,6 +143,8 @@ convert_object_map = {
|
|||
T.isinstance: Primitive('isinstance'),
|
||||
T.max: M.ms_max,
|
||||
T.min: M.ms_min,
|
||||
T.getattr: Primitive('getattr'),
|
||||
T.hasattr: M.hasattr,
|
||||
|
||||
# custom define operation
|
||||
T.iter: M.ms_iter,
|
||||
|
|
|
@ -258,6 +258,21 @@ def strides_(x):
|
|||
return strides
|
||||
|
||||
|
||||
def hasattr(x, attr): # pylint: disable=redefined-builtin
|
||||
"""
|
||||
Return whether an object has the attribute.
|
||||
|
||||
Args:
|
||||
x (object): Input object.
|
||||
attr (string): The name of attribute
|
||||
|
||||
Returns:
|
||||
Boolean value, indicates whether the object x has attribute attr.
|
||||
"""
|
||||
out = getattr(x, attr, None)
|
||||
return out is not None
|
||||
|
||||
|
||||
def astype(x, dtype, copy=True): # pylint: disable=redefined-outer-name
|
||||
"""
|
||||
Return a copy of the tensor, casted to a specified type.
|
||||
|
|
|
@ -28,7 +28,7 @@ from operator import ( # noqa
|
|||
|
||||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, len, iter, next, pow, range, map, zip,
|
||||
bool, getattr, setattr, hasattr, len, iter, next, pow, range, map, zip,
|
||||
print, enumerate, isinstance, filter, abs, all, any, round, max, min
|
||||
)
|
||||
|
||||
|
@ -45,7 +45,7 @@ from numpy import ( # noqa
|
|||
__all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'gt', 'le', 'ge', 'pos', 'neg',
|
||||
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
|
||||
'matmul', 'getitem', 'setitem',
|
||||
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'bool', 'getattr', 'setattr', 'hasattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
|
||||
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min']
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ def test_catch_exception_stack_trace_log():
|
|||
assert os.path.exists(log_file_name)
|
||||
with open(log_file_name, "r") as f_first:
|
||||
data_first = f_first.read()
|
||||
assert "Not supported to get attribute" in data_first
|
||||
assert "Do not support to get attribute" in data_first
|
||||
assert "x = self.y.tt1" in data_first
|
||||
|
||||
# Clean files
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test graph getattr, hasattr"""
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_tensor():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs")
|
||||
return abs_func()
|
||||
|
||||
out = foo(Tensor([-1, -2, -3]))
|
||||
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_tensor_with_concate_string():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input and concate string.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
attr_str = "a" + "bs"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo(Tensor([-1, -2, -3]))
|
||||
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
|
|
@ -0,0 +1,705 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test graph getattr"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ms_function, ms_class, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_getattr_tensor_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = Tensor([-1, -2, -3])
|
||||
abs_func = getattr(x, "abs")
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_getattr_tensor_with_concate_string_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input and concate string.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
attr_str = "a" + "bs"
|
||||
abs_func = getattr(Tensor([-1, -2, -3]), attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_getattr_tensor_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo(Tensor([-1, -2, -3]))
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_tensor_with_default():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
abs_func = getattr(Tensor([-1, -2, -3]), "abs", Tensor([-1, -2, -3]))
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_getattr_tensor_with_default_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
abs_func = getattr(Tensor([-1, -2, -3]), "abs2", Tensor([-1, -2, -3]))
|
||||
return abs_func
|
||||
|
||||
out = foo()
|
||||
assert np.all(out.asnumpy() == np.array([-1, -2, -3]))
|
||||
|
||||
|
||||
def test_getattr_list():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
abs_func = getattr(x, "__len__")
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_list_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "__len__")
|
||||
return abs_func()
|
||||
|
||||
out = foo([1, 2, 3, 4])
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_list_with_concate_input():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
attr_str = "__" + "len" + "__"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_with_concate_input_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
attr_str = "__" + "len" + "__"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo([1, 2, 3, 4])
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_list_with_default():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input with default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
abs_func = getattr(x, "__len__", Tensor([-1]))
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_list_with_default_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input with default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
abs_func = getattr(x, "__len2__", Tensor([-1]))
|
||||
return abs_func
|
||||
|
||||
out = foo()
|
||||
assert out == -1
|
||||
|
||||
|
||||
def test_getattr_list_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo([1, 2, 3, 4])
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_tuple():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tuple input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (1, 2, 3, 4)
|
||||
abs_func = getattr(x, "__len__")
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_tuple_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tuple input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "__len__")
|
||||
return abs_func()
|
||||
|
||||
out = foo((1, 2, 3, 4))
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_tuple_with_concate_input():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tuple input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (1, 2, 3, 4)
|
||||
attr_str = "__" + "len" + "__"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_tuple_getattr_with_concate_input_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tuple input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
attr_str = "__" + "len" + "__"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo((1, 2, 3, 4))
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_tuple_with_default():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tuple input with default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (1, 2, 3, 4)
|
||||
abs_func = getattr(x, "__len__", Tensor([-1]))
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_getattr_tuple_with_default_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tuple input with default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (1, 2, 3, 4)
|
||||
abs_func = getattr(x, "__len2__", Tensor([-1]))
|
||||
return abs_func
|
||||
|
||||
out = foo()
|
||||
assert out == -1
|
||||
|
||||
|
||||
def test_getattr_tuple_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "shape")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo((1, 2, 3, 4))
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_dict():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support dict input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = {"1": 1, "2": 2}
|
||||
abs_func = getattr(x, "__len__")
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_getattr_dict_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support dict input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "__len__")
|
||||
return abs_func()
|
||||
|
||||
out = foo({"1": 1, "2": 2})
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_getattr_dict_with_concate_input():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support dict input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = {"1": 1, "2": 2}
|
||||
attr_str = "__" + "len" + "__"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_getattr_dict_with_concate_input_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support dict input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
attr_str = "__" + "len" + "__"
|
||||
abs_func = getattr(x, attr_str)
|
||||
return abs_func()
|
||||
|
||||
out = foo({"1": 1, "2": 2})
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_getattr_dict_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo({"1": 1, "2": 2})
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_dict_with_default():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support dict input with default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = {"1": 1, "2": 2}
|
||||
abs_func = getattr(x, "__len__", Tensor([-1]))
|
||||
return abs_func()
|
||||
|
||||
out = foo()
|
||||
assert out == 2
|
||||
|
||||
|
||||
def test_getattr_dict_with_default_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support dict input with default.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = {"1": 1, "2": 2}
|
||||
abs_func = getattr(x, "__len2__", Tensor([-1]))
|
||||
return abs_func
|
||||
|
||||
out = foo()
|
||||
assert out == -1
|
||||
|
||||
|
||||
@ms_class
|
||||
class MSClass1:
|
||||
def __init__(self):
|
||||
self.num0 = Tensor(0)
|
||||
self.num1 = Tensor(1)
|
||||
self.num2 = Tensor(2)
|
||||
self.num3 = Tensor(3)
|
||||
self.none = None
|
||||
|
||||
|
||||
def test_getattr_ms_class():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return getattr(ms_obj, "num1")
|
||||
|
||||
out = foo()
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_getattr_ms_class_with_concate_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ret = 0
|
||||
nums = ["0", "1", "2", "3"]
|
||||
for i in range(4):
|
||||
attr_str = "num" + nums[i]
|
||||
ret = ret + getattr(ms_obj, attr_str)
|
||||
return ret
|
||||
|
||||
out = foo()
|
||||
assert out == 6
|
||||
|
||||
|
||||
def test_getattr_ms_class_with_default():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return getattr(ms_obj, "none", 10)
|
||||
|
||||
out = foo()
|
||||
assert out == 10
|
||||
|
||||
|
||||
def test_getattr_ms_class_with_concate_attr_and_default():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ret = 0
|
||||
nums = ["0", "1", "2", "3", "4"]
|
||||
for i in range(5):
|
||||
attr_str = "num" + nums[i]
|
||||
ret = ret + getattr(ms_obj, attr_str, Tensor([4]))
|
||||
return ret
|
||||
|
||||
out = foo()
|
||||
assert out == 10
|
||||
|
||||
|
||||
def test_getattr_ms_class_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
abs_func = getattr(ms_obj, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo()
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_ms_class_with_wrong_attr_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
abs_func = getattr(ms_obj, "none")
|
||||
return abs_func
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo()
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.a0 = Tensor([0])
|
||||
self.a1 = Tensor([1])
|
||||
self.a2 = Tensor([2])
|
||||
self.a3 = Tensor([3])
|
||||
|
||||
def construct(self):
|
||||
return self.a0
|
||||
|
||||
|
||||
def test_getattr_cell_obj():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support cell object input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return getattr(cell_obj, "a0")
|
||||
|
||||
out = foo()
|
||||
assert out == 0
|
||||
|
||||
|
||||
def test_getattr_cell_obj_concate_input():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support cell object input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 0
|
||||
attrs = ["0", "1", "2", "3"]
|
||||
for attr in attrs:
|
||||
a = a + getattr(cell_obj, "a" + attr)
|
||||
return a
|
||||
|
||||
out = foo()
|
||||
assert out == 6
|
||||
|
||||
|
||||
def test_getattr_cell_obj_concate_input_and_default_value():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support cell object input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 0
|
||||
attrs = ["0", "1", "2", "3", "4"]
|
||||
for attr in attrs:
|
||||
a = a + getattr(cell_obj, "a" + attr, Tensor([4]))
|
||||
return a
|
||||
|
||||
out = foo()
|
||||
assert out == 10
|
||||
|
||||
|
||||
def test_getattr_cell_obj_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
abs_func = getattr(cell_obj, "foo")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo()
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_numpy_array():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support numpy array input.
|
||||
Expectation: TypeError
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
return getattr(x, "shape")[0]
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
foo()
|
||||
assert "Do not support to get attribute from interpret object." in str(err.value)
|
||||
|
||||
|
||||
def test_getattr_numpy_array_2():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support numpy array input.
|
||||
Expectation: TypeError
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = 1
|
||||
return getattr(x, "shape", np.array([0, 1, 2, 3, 4]))
|
||||
|
||||
with pytest.raises(TypeError) as err:
|
||||
foo()
|
||||
assert "For 'getattr', the third input 'default' can not be interpreted object." in str(err.value)
|
|
@ -0,0 +1,303 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test graph hasattr"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, ms_function, ms_class, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_hasattr_tensor():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
return hasattr(x, "abs")
|
||||
|
||||
assert foo(Tensor([-1, -2, -3]))
|
||||
|
||||
|
||||
def test_hasattr_tensor_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support tensor input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo(x):
|
||||
return hasattr(x, "abs2")
|
||||
|
||||
assert not foo(Tensor([-1, -2, -3]))
|
||||
|
||||
|
||||
def test_hasattr_list():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
attr = "__" + "len" + "__"
|
||||
return hasattr(x, attr)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_hasattr_list_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support list input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
attr = "__" + "len2" + "__"
|
||||
return hasattr(x, attr)
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_hasattr_tuple():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support tuple input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (1, 2, 3, 4)
|
||||
attr = "__" + "len" + "__"
|
||||
return hasattr(x, attr)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_hasattr_tuple_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support tuple input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = (1, 2, 3, 4)
|
||||
attr = "__" + "len2" + "__"
|
||||
return hasattr(x, attr)
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_hasattr_dict():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support dict input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = {"1": 1, "2": 2}
|
||||
attr = "__" + "len" + "__"
|
||||
return hasattr(x, attr)
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_hasattr_dict_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support dict input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = {"1": 1, "2": 2}
|
||||
attr = "__" + "len2" + "__"
|
||||
return hasattr(x, attr)
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
@ms_class
|
||||
class MSClass1:
|
||||
def __init__(self):
|
||||
self.num0 = Tensor(0)
|
||||
self.num1 = Tensor(1)
|
||||
self.num2 = Tensor(2)
|
||||
self.num3 = Tensor(3)
|
||||
self.none = None
|
||||
|
||||
|
||||
def test_hasattr_ms_class():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return hasattr(ms_obj, "num1")
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_hasattr_ms_class_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return hasattr(ms_obj, "none")
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_hasattr_ms_class_with_concate_attr():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support ms_class input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
ms_obj = MSClass1()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
ret = 0
|
||||
nums = ["0", "1", "2", "3", "4"]
|
||||
for i in range(5):
|
||||
attr_str = "num" + nums[i]
|
||||
if hasattr(ms_obj, attr_str):
|
||||
ret = ret + 1
|
||||
return ret
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.a0 = Tensor([0])
|
||||
self.a1 = Tensor([1])
|
||||
self.a2 = Tensor([2])
|
||||
self.a3 = Tensor([3])
|
||||
self.none = None
|
||||
|
||||
def construct(self):
|
||||
return self.a0
|
||||
|
||||
|
||||
def test_hasattr_cell_obj():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support cell object input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return hasattr(cell_obj, "a0")
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_hasattr_cell_obj_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support cell object input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
return hasattr(cell_obj, "none")
|
||||
|
||||
assert not foo()
|
||||
|
||||
|
||||
def test_hasattr_cell_obj_concate_input():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support cell object input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
cell_obj = Net()
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
a = 0
|
||||
attrs = ["0", "1", "2", "3", "4"]
|
||||
for attr in attrs:
|
||||
if hasattr(cell_obj, "a" + attr):
|
||||
a = a + 1
|
||||
return a
|
||||
|
||||
out = foo()
|
||||
assert out == 4
|
||||
|
||||
|
||||
def test_hasattr_numpy_array():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support numpy array input.
|
||||
Expectation: TypeError
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
return hasattr(x, "shape")
|
||||
|
||||
assert foo()
|
||||
|
||||
|
||||
def test_hasattr_numpy_array_2():
|
||||
"""
|
||||
Feature: Syntax hasattr.
|
||||
Description: Graph syntax hasattr support numpy array input.
|
||||
Expectation: TypeError
|
||||
"""
|
||||
|
||||
@ms_function
|
||||
def foo():
|
||||
x = np.array([1, 2, 3, 4])
|
||||
return hasattr(x, "shape2")
|
||||
|
||||
assert not foo()
|
|
@ -27,8 +27,6 @@ context.set_context(mode=context.GRAPH_MODE)
|
|||
|
||||
def test_dtype_and_shape_as_attr():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
shape = x.shape
|
||||
|
@ -62,8 +60,6 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
|
|||
|
||||
def test_type_not_have_the_attr():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
shape = x.shapes
|
||||
|
@ -71,14 +67,12 @@ def test_type_not_have_the_attr():
|
|||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(AttributeError):
|
||||
net(x)
|
||||
|
||||
|
||||
def test_type_not_have_the_method():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
shape = x.dtypes()
|
||||
|
@ -86,5 +80,5 @@ def test_type_not_have_the_method():
|
|||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(AttributeError):
|
||||
net(x)
|
||||
|
|
Loading…
Reference in New Issue