support graph hasattr and getattr

This commit is contained in:
liangzhibo 2022-08-12 17:25:20 +08:00
parent a11e01a0d2
commit 88bc967cdd
12 changed files with 1236 additions and 48 deletions

View File

@ -46,7 +46,11 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
if (IsValueNode<parse::MsClassObject>(object_node)) {
auto ms_class = GetValueNode<parse::MsClassObjectPtr>(object_node)->obj();
auto attr_str = GetValue<std::string>(GetValueNode(attr_node));
return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
auto new_node = parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node);
if (new_node == nullptr || IsValueNode<None>(new_node)) {
MS_EXCEPTION(AttributeError) << py::str(ms_class) << " object has no attribute: " << attr_str << ".";
}
return new_node;
}
// {prim::kPrimGetAttr, bool, attr}
if (IsValueNode<BoolImm>(object_node)) {

View File

@ -495,6 +495,9 @@ AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py:
MS_EXCEPTION_IF_NULL(attr_str_ptr);
const auto &attr_str = attr_str_ptr->value();
auto res = ResolveMsClassWithAttr(manager, sequence[i], attr_str, get_attr_node);
if (res == nullptr || IsValueNode<None>(res)) {
MS_EXCEPTION(AttributeError) << py::str(sequence[i]) << " object has no attribute: " << attr_str << ".";
}
(void)inputs.emplace_back(res);
}
} else {
@ -536,9 +539,12 @@ AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const py::
MS_LOG(EXCEPTION) << attr << " is a private variable or magic method, which is not supported.";
}
if (!py::hasattr(cls_obj, common::SafeCStr(attr))) {
MS_LOG(EXCEPTION) << py::str(cls_obj) << " has not attribute: " << attr << ".";
return nullptr;
}
py::object attr_obj = py::getattr(cls_obj, common::SafeCStr(attr));
if (py::isinstance<py::none>(attr_obj)) {
return nullptr;
}
AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, get_attr_node);
TraceManager::ClearParseOrResolveDebugInfo();
return res_node;

View File

@ -1368,11 +1368,13 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
return eng->ForwardConfig(old_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, const ValuePtr &attr_value,
const ValuePtr &data_value, const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_spec_list, const ValuePtr &data_value,
const AnfNodeConfigPtr &out_conf, const std::string &data) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(data_value);
MS_EXCEPTION_IF_NULL(attr_value);
ValuePtr item_value = attr_value;
MS_EXCEPTION_IF_NULL(item_value);
if (item_value->isa<StringImm>()) {
item_value = std::make_shared<parse::Symbol>(item_value->cast_ptr<StringImm>()->value());
}
@ -1391,6 +1393,23 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
if (new_node == nullptr) {
MS_LOG(EXCEPTION) << "Resolve node failed";
}
if (IsValueNode<None>(new_node)) {
// Do not find the attribute.
constexpr auto kMaxArgsLen = 3;
bool has_default = (args_spec_list.size() == kMaxArgsLen);
if (!has_default) {
MS_EXCEPTION(AttributeError) << data << " object has no attribute " << symbol->symbol();
}
auto out_cnode = out_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(out_cnode);
constexpr auto kDefaultIndex = 3;
auto default_node = out_cnode->inputs()[kDefaultIndex];
func_graph->ReplaceInOrder(out_node, default_node);
auto eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
if (pipeline::GetJitLevel() == "O0" && IsValueNode<FuncGraph>(new_node)) {
UpdateDebugInfo(GetValueNode<FuncGraphPtr>(new_node), out_node->scope(), out_node->debug_info());
}
@ -1404,7 +1423,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, con
return eng->ForwardConfig(out_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForNameSpace(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) {
// args_spec_list: same as StaticGetter
constexpr size_t args_min_size = 2;
@ -1413,23 +1432,33 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AnalysisEnginePtr &, const Abs
}
MS_EXCEPTION_IF_NULL(out_conf);
// An external type.
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
auto data_value = args_spec_list[0]->BuildValue();
constexpr auto kDataIndex = 0;
constexpr auto kItemIndex = 1;
auto data = args_spec_list[kDataIndex];
auto item = args_spec_list[kItemIndex];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(item);
auto data_value = data->BuildValue();
MS_EXCEPTION_IF_NULL(data_value);
if (!data_value->isa<parse::NameSpace>()) {
MS_EXCEPTION(TypeError) << "Not supported to get attribute for " << data_value->ToString()
<< "\nThe first argument should be a NameSpace, but got " << args_spec_list[0]->ToString();
<< "\nThe first argument should be a NameSpace, but got " << data->ToString();
}
auto item_value = args_spec_list[1]->BuildValue();
auto item_value = item->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
return GetEvaluatedValueForNameSpaceString(nullptr, item_value, data_value, out_conf);
auto data_type = data->BuildType();
MS_EXCEPTION_IF_NULL(data_type);
const auto &data_id_str = TypeIdToString(data_type->type_id());
return GetEvaluatedValueForNameSpaceString(args_spec_list, data_value, out_conf, data_id_str);
}
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &, const ValuePtr &item_value,
const ValuePtr &data_value, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList &args_spec_list,
const ValuePtr &data_value, const AnfNodeConfigPtr &out_conf) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
MS_EXCEPTION_IF_NULL(data_value);
// Get the name of item.
@ -1447,8 +1476,20 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &,
// Get the attr/method of ms_class object.
auto out_node = out_conf->node();
FuncGraphPtr func_graph = out_node->func_graph();
// If the attribute is not found and the default is not set, AttributeError will be raised.
auto new_node = parse::ResolveMsClassWithAttr(func_graph->manager(), ms_class->obj(), item_name, out_node);
// Replace old node with the resolved new node in order list.
if (new_node == nullptr || IsValueNode<None>(new_node)) {
constexpr auto kMaxArgsLen = 3;
bool has_default = (args_spec_list.size() == kMaxArgsLen);
if (!has_default) {
MS_EXCEPTION(AttributeError) << py::str(ms_class->obj()) << " object has no attribute: " << item_name << ".";
}
constexpr auto kDefaultIndex = 3;
auto out_cnode = out_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(out_cnode);
new_node = out_cnode->inputs()[kDefaultIndex];
}
func_graph->ReplaceInOrder(out_node, new_node);
AnalysisEnginePtr eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
@ -1456,9 +1497,12 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AnalysisEnginePtr &,
return eng->ForwardConfig(out_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const FuncGraphPtr &func_value, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &args_spec_list,
const FuncGraphPtr &func_value, const AnfNodeConfigPtr &out_conf) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
MS_EXCEPTION_IF_NULL(func_value);
if (!item_value->isa<StringImm>()) {
@ -1471,22 +1515,32 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AnalysisEnginePtr &engi
auto wrapper_obj = dyn_cast_ptr<parse::PyObjectWrapper>(python_obj);
MS_EXCEPTION_IF_NULL(wrapper_obj);
py::object real_python_obj = wrapper_obj->obj();
const auto &py_obj_str = py::str(real_python_obj);
MS_LOG(DEBUG) << "item_value: " << item_value->ToString() << ", func_value: " << func_value->ToString()
<< ", real_python_obj: " << py::str(real_python_obj);
<< ", real_python_obj: " << py_obj_str;
if (py::isinstance<Cell>(real_python_obj)) {
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object ns_obj =
python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, real_python_obj);
auto ns = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
return GetEvaluatedValueForNameSpaceString(nullptr, item_value, ns, out_conf);
return GetEvaluatedValueForNameSpaceString(args_spec_list, ns, out_conf, py_obj_str);
}
return nullptr;
}
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
const TypePtr &data_type, const ConfigPtr &data_conf,
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
auto data_args = args_spec_list[data_index];
auto item_args = args_spec_list[item_index];
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
TypePtr data_type = data_args->BuildType();
MS_EXCEPTION_IF_NULL(data_type);
// The method maybe a Primitive or Composite
if (!item_value->isa<StringImm>()) {
@ -1498,8 +1552,22 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
if (require.empty()) {
require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
if (require.empty()) {
MS_LOG(EXCEPTION) << "MindSpore not support to get attribute \'" << item_name << "\' of a type["
<< data_type->ToString() << "]";
constexpr auto kMaxArgsLen = 3;
bool has_default = (args_spec_list.size() == kMaxArgsLen);
if (!has_default) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
}
auto out_node = out_conf->node();
auto out_cnode = out_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(out_cnode);
auto fg = out_cnode->func_graph();
constexpr auto kDefaultIndex = 3;
auto default_node = out_cnode->inputs()[kDefaultIndex];
fg->ReplaceInOrder(out_node, default_node);
auto eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
require_type = REQUIRE_TYPE::ATTR;
}
@ -1550,7 +1618,6 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function
CheckArgsSize("StaticGetter", args_spec_list, 2);
constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
@ -1571,14 +1638,32 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
}
if (data_args->isa<abstract::AbstractScalar>()) {
ValuePtr data_value = data_args->BuildValue();
if (data_value->isa<parse::InterpretedObject>()) {
MS_EXCEPTION(TypeError) << "Do not support to get attribute from interpret object.";
}
}
constexpr auto KMaxArgSize = 3;
if (args_spec_list.size() == KMaxArgSize) {
constexpr size_t default_index = 2;
auto default_args = args_spec_list[default_index];
if (default_args->isa<abstract::AbstractScalar>()) {
ValuePtr default_value = default_args->BuildValue();
if (default_value->isa<parse::InterpretedObject>()) {
MS_EXCEPTION(TypeError) << "For 'getattr', the third input 'default' can not be interpreted object.";
}
}
}
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(engine, item_value, class_value, data_conf, out_conf);
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
}
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res =
GetEvaluatedValueForCellAttrOrMethod(engine, item_value, data_func_graph->func_graph(), data_conf, out_conf);
auto res = GetEvaluatedValueForCellAttrOrMethod(args_spec_list, data_func_graph->func_graph(), out_conf);
if (res != nullptr) {
return res;
}
@ -1586,9 +1671,9 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
// Try to search method map, if not found, the data_type should be External type.
TypePtr data_type = data_args->BuildType();
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_spec_list, data_conf, out_conf);
}
return GetEvaluatedValueForNameSpace(engine, args_spec_list, out_conf);
return GetEvaluatedValueForNameSpace(args_spec_list, out_conf);
}
} // namespace
@ -1748,15 +1833,25 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
constexpr auto kGetAttrArgSize = 2;
constexpr auto kGetAttrArgMinSize = 2;
constexpr auto kGetAttrArgMaxSize = 3;
constexpr auto kAttrIndex = 1;
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
return ret_abstract;
}
// Inputs: data, item
if (args_spec_list.size() != kGetAttrArgSize) {
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
const auto args_size = args_spec_list.size();
if (args_size != kGetAttrArgMinSize && args_size != kGetAttrArgMaxSize) {
MS_LOG(EXCEPTION) << "For Primitive GetAttr, the input size should be 2 or 3, but got size:" << args_size;
}
auto attr_abs = args_spec_list[kAttrIndex];
auto attr_abs_type = attr_abs->BuildType();
MS_EXCEPTION_IF_NULL(attr_abs_type);
auto type_id = attr_abs_type->type_id();
if (type_id != TypeId::kObjectTypeString) {
MS_EXCEPTION(TypeError) << "getattr(): attribute name must be string but got: " << TypeIdToString(type_id);
}
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {

View File

@ -107,12 +107,12 @@ _unsupported_internal_type = (
)
_hybrid_type = (
print, len, enumerate, zip, map, filter, abs, all, any, round, max, min,
print, len, enumerate, zip, map, filter, abs, all, any, round, max, min, hasattr
)
# Unsupported python builtin type in JIT Fallback.
_fallback_unsupported_python_builtin_type = (
compile, eval, exec, input, open, delattr, setattr, getattr, hasattr, super, staticmethod, classmethod, __import__,
compile, eval, exec, input, open, delattr, setattr, super, staticmethod, classmethod, __import__,
memoryview, property,
)

View File

@ -143,6 +143,8 @@ convert_object_map = {
T.isinstance: Primitive('isinstance'),
T.max: M.ms_max,
T.min: M.ms_min,
T.getattr: Primitive('getattr'),
T.hasattr: M.hasattr,
# custom define operation
T.iter: M.ms_iter,

View File

@ -258,6 +258,21 @@ def strides_(x):
return strides
def hasattr(x, attr): # pylint: disable=redefined-builtin
"""
Return whether an object has the attribute.
Args:
x (object): Input object.
attr (string): The name of attribute
Returns:
Boolean value, indicates whether the object x has attribute attr.
"""
out = getattr(x, attr, None)
return out is not None
def astype(x, dtype, copy=True): # pylint: disable=redefined-outer-name
"""
Return a copy of the tensor, casted to a specified type.

View File

@ -28,7 +28,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, len, iter, next, pow, range, map, zip,
bool, getattr, setattr, hasattr, len, iter, next, pow, range, map, zip,
print, enumerate, isinstance, filter, abs, all, any, round, max, min
)
@ -45,7 +45,7 @@ from numpy import ( # noqa
__all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'gt', 'le', 'ge', 'pos', 'neg',
'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'bool', 'getattr', 'setattr', 'hasattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min']

View File

@ -44,7 +44,7 @@ def test_catch_exception_stack_trace_log():
assert os.path.exists(log_file_name)
with open(log_file_name, "r") as f_first:
data_first = f_first.read()
assert "Not supported to get attribute" in data_first
assert "Do not support to get attribute" in data_first
assert "x = self.y.tt1" in data_first
# Clean files

View File

@ -0,0 +1,64 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test graph getattr, hasattr"""
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_getattr_tensor():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: No exception.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "abs")
return abs_func()
out = foo(Tensor([-1, -2, -3]))
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_getattr_tensor_with_concate_string():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input and concate string.
Expectation: No exception.
"""
@ms_function
def foo(x):
attr_str = "a" + "bs"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo(Tensor([-1, -2, -3]))
assert np.all(out.asnumpy() == np.array([1, 2, 3]))

View File

@ -0,0 +1,705 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test graph getattr"""
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ms_function, ms_class, context
context.set_context(mode=context.GRAPH_MODE)
def test_getattr_tensor_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
x = Tensor([-1, -2, -3])
abs_func = getattr(x, "abs")
return abs_func()
out = foo()
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
def test_getattr_tensor_with_concate_string_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input and concate string.
Expectation: No exception.
"""
@ms_function
def foo():
attr_str = "a" + "bs"
abs_func = getattr(Tensor([-1, -2, -3]), attr_str)
return abs_func()
out = foo()
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
def test_getattr_tensor_with_wrong_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: AttributeError.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "abs2")
return abs_func()
with pytest.raises(AttributeError) as err:
foo(Tensor([-1, -2, -3]))
assert "object has no attribute" in str(err.value)
def test_getattr_tensor_with_default():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
abs_func = getattr(Tensor([-1, -2, -3]), "abs", Tensor([-1, -2, -3]))
return abs_func()
out = foo()
assert np.all(out.asnumpy() == np.array([1, 2, 3]))
def test_getattr_tensor_with_default_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: No exception.
"""
@ms_function
def foo():
abs_func = getattr(Tensor([-1, -2, -3]), "abs2", Tensor([-1, -2, -3]))
return abs_func
out = foo()
assert np.all(out.asnumpy() == np.array([-1, -2, -3]))
def test_getattr_list():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3, 4]
abs_func = getattr(x, "__len__")
return abs_func()
out = foo()
assert out == 4
def test_getattr_list_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: No exception.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "__len__")
return abs_func()
out = foo([1, 2, 3, 4])
assert out == 4
def test_getattr_list_with_concate_input():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3, 4]
attr_str = "__" + "len" + "__"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo()
assert out == 4
def test_getattr_with_concate_input_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: No exception.
"""
@ms_function
def foo(x):
attr_str = "__" + "len" + "__"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo([1, 2, 3, 4])
assert out == 4
def test_getattr_list_with_default():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input with default.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3, 4]
abs_func = getattr(x, "__len__", Tensor([-1]))
return abs_func()
out = foo()
assert out == 4
def test_getattr_list_with_default_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input with default.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3, 4]
abs_func = getattr(x, "__len2__", Tensor([-1]))
return abs_func
out = foo()
assert out == -1
def test_getattr_list_with_wrong_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: AttributeError.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "abs2")
return abs_func()
with pytest.raises(AttributeError) as err:
foo([1, 2, 3, 4])
assert "object has no attribute" in str(err.value)
def test_getattr_tuple():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tuple input.
Expectation: No exception.
"""
@ms_function
def foo():
x = (1, 2, 3, 4)
abs_func = getattr(x, "__len__")
return abs_func()
out = foo()
assert out == 4
def test_getattr_tuple_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tuple input.
Expectation: No exception.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "__len__")
return abs_func()
out = foo((1, 2, 3, 4))
assert out == 4
def test_getattr_tuple_with_concate_input():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tuple input.
Expectation: No exception.
"""
@ms_function
def foo():
x = (1, 2, 3, 4)
attr_str = "__" + "len" + "__"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo()
assert out == 4
def test_tuple_getattr_with_concate_input_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tuple input.
Expectation: No exception.
"""
@ms_function
def foo(x):
attr_str = "__" + "len" + "__"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo((1, 2, 3, 4))
assert out == 4
def test_getattr_tuple_with_default():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tuple input with default.
Expectation: No exception.
"""
@ms_function
def foo():
x = (1, 2, 3, 4)
abs_func = getattr(x, "__len__", Tensor([-1]))
return abs_func()
out = foo()
assert out == 4
def test_getattr_tuple_with_default_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tuple input with default.
Expectation: No exception.
"""
@ms_function
def foo():
x = (1, 2, 3, 4)
abs_func = getattr(x, "__len2__", Tensor([-1]))
return abs_func
out = foo()
assert out == -1
def test_getattr_tuple_with_wrong_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: AttributeError.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "shape")
return abs_func()
with pytest.raises(AttributeError) as err:
foo((1, 2, 3, 4))
assert "object has no attribute" in str(err.value)
def test_getattr_dict():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support dict input.
Expectation: No exception.
"""
@ms_function
def foo():
x = {"1": 1, "2": 2}
abs_func = getattr(x, "__len__")
return abs_func()
out = foo()
assert out == 2
def test_getattr_dict_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support dict input.
Expectation: No exception.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "__len__")
return abs_func()
out = foo({"1": 1, "2": 2})
assert out == 2
def test_getattr_dict_with_concate_input():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support dict input.
Expectation: No exception.
"""
@ms_function
def foo():
x = {"1": 1, "2": 2}
attr_str = "__" + "len" + "__"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo()
assert out == 2
def test_getattr_dict_with_concate_input_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support dict input.
Expectation: No exception.
"""
@ms_function
def foo(x):
attr_str = "__" + "len" + "__"
abs_func = getattr(x, attr_str)
return abs_func()
out = foo({"1": 1, "2": 2})
assert out == 2
def test_getattr_dict_with_wrong_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: AttributeError.
"""
@ms_function
def foo(x):
abs_func = getattr(x, "abs2")
return abs_func()
with pytest.raises(AttributeError) as err:
foo({"1": 1, "2": 2})
assert "object has no attribute" in str(err.value)
def test_getattr_dict_with_default():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support dict input with default.
Expectation: No exception.
"""
@ms_function
def foo():
x = {"1": 1, "2": 2}
abs_func = getattr(x, "__len__", Tensor([-1]))
return abs_func()
out = foo()
assert out == 2
def test_getattr_dict_with_default_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support dict input with default.
Expectation: No exception.
"""
@ms_function
def foo():
x = {"1": 1, "2": 2}
abs_func = getattr(x, "__len2__", Tensor([-1]))
return abs_func
out = foo()
assert out == -1
@ms_class
class MSClass1:
def __init__(self):
self.num0 = Tensor(0)
self.num1 = Tensor(1)
self.num2 = Tensor(2)
self.num3 = Tensor(3)
self.none = None
def test_getattr_ms_class():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
return getattr(ms_obj, "num1")
out = foo()
assert out == 1
def test_getattr_ms_class_with_concate_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
ret = 0
nums = ["0", "1", "2", "3"]
for i in range(4):
attr_str = "num" + nums[i]
ret = ret + getattr(ms_obj, attr_str)
return ret
out = foo()
assert out == 6
def test_getattr_ms_class_with_default():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
return getattr(ms_obj, "none", 10)
out = foo()
assert out == 10
def test_getattr_ms_class_with_concate_attr_and_default():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
ret = 0
nums = ["0", "1", "2", "3", "4"]
for i in range(5):
attr_str = "num" + nums[i]
ret = ret + getattr(ms_obj, attr_str, Tensor([4]))
return ret
out = foo()
assert out == 10
def test_getattr_ms_class_with_wrong_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: AttributeError.
"""
ms_obj = MSClass1()
@ms_function
def foo():
abs_func = getattr(ms_obj, "abs2")
return abs_func()
with pytest.raises(AttributeError) as err:
foo()
assert "object has no attribute" in str(err.value)
def test_getattr_ms_class_with_wrong_attr_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support list input.
Expectation: AttributeError.
"""
ms_obj = MSClass1()
@ms_function
def foo():
abs_func = getattr(ms_obj, "none")
return abs_func
with pytest.raises(AttributeError) as err:
foo()
assert "object has no attribute" in str(err.value)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.a0 = Tensor([0])
self.a1 = Tensor([1])
self.a2 = Tensor([2])
self.a3 = Tensor([3])
def construct(self):
return self.a0
def test_getattr_cell_obj():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support cell object input.
Expectation: No exception.
"""
cell_obj = Net()
@ms_function
def foo():
return getattr(cell_obj, "a0")
out = foo()
assert out == 0
def test_getattr_cell_obj_concate_input():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support cell object input.
Expectation: No exception.
"""
cell_obj = Net()
@ms_function
def foo():
a = 0
attrs = ["0", "1", "2", "3"]
for attr in attrs:
a = a + getattr(cell_obj, "a" + attr)
return a
out = foo()
assert out == 6
def test_getattr_cell_obj_concate_input_and_default_value():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support cell object input.
Expectation: No exception.
"""
cell_obj = Net()
@ms_function
def foo():
a = 0
attrs = ["0", "1", "2", "3", "4"]
for attr in attrs:
a = a + getattr(cell_obj, "a" + attr, Tensor([4]))
return a
out = foo()
assert out == 10
def test_getattr_cell_obj_with_wrong_attr():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support tensor input.
Expectation: AttributeError.
"""
cell_obj = Net()
@ms_function
def foo():
abs_func = getattr(cell_obj, "foo")
return abs_func()
with pytest.raises(AttributeError) as err:
foo()
assert "object has no attribute" in str(err.value)
def test_getattr_numpy_array():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support numpy array input.
Expectation: TypeError
"""
@ms_function
def foo():
x = np.array([1, 2, 3, 4])
return getattr(x, "shape")[0]
with pytest.raises(TypeError) as err:
foo()
assert "Do not support to get attribute from interpret object." in str(err.value)
def test_getattr_numpy_array_2():
"""
Feature: Syntax getattr.
Description: Graph syntax getattr support numpy array input.
Expectation: TypeError
"""
@ms_function
def foo():
x = 1
return getattr(x, "shape", np.array([0, 1, 2, 3, 4]))
with pytest.raises(TypeError) as err:
foo()
assert "For 'getattr', the third input 'default' can not be interpreted object." in str(err.value)

View File

@ -0,0 +1,303 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test graph hasattr"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ms_function, ms_class, context
context.set_context(mode=context.GRAPH_MODE)
def test_hasattr_tensor():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support tensor input.
Expectation: No exception.
"""
@ms_function
def foo(x):
return hasattr(x, "abs")
assert foo(Tensor([-1, -2, -3]))
def test_hasattr_tensor_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support tensor input.
Expectation: No exception.
"""
@ms_function
def foo(x):
return hasattr(x, "abs2")
assert not foo(Tensor([-1, -2, -3]))
def test_hasattr_list():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support list input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3, 4]
attr = "__" + "len" + "__"
return hasattr(x, attr)
assert foo()
def test_hasattr_list_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support list input.
Expectation: No exception.
"""
@ms_function
def foo():
x = [1, 2, 3, 4]
attr = "__" + "len2" + "__"
return hasattr(x, attr)
assert not foo()
def test_hasattr_tuple():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support tuple input.
Expectation: No exception.
"""
@ms_function
def foo():
x = (1, 2, 3, 4)
attr = "__" + "len" + "__"
return hasattr(x, attr)
assert foo()
def test_hasattr_tuple_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support tuple input.
Expectation: No exception.
"""
@ms_function
def foo():
x = (1, 2, 3, 4)
attr = "__" + "len2" + "__"
return hasattr(x, attr)
assert not foo()
def test_hasattr_dict():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support dict input.
Expectation: No exception.
"""
@ms_function
def foo():
x = {"1": 1, "2": 2}
attr = "__" + "len" + "__"
return hasattr(x, attr)
assert foo()
def test_hasattr_dict_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support dict input.
Expectation: No exception.
"""
@ms_function
def foo():
x = {"1": 1, "2": 2}
attr = "__" + "len2" + "__"
return hasattr(x, attr)
assert not foo()
@ms_class
class MSClass1:
def __init__(self):
self.num0 = Tensor(0)
self.num1 = Tensor(1)
self.num2 = Tensor(2)
self.num3 = Tensor(3)
self.none = None
def test_hasattr_ms_class():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
return hasattr(ms_obj, "num1")
assert foo()
def test_hasattr_ms_class_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
return hasattr(ms_obj, "none")
assert not foo()
def test_hasattr_ms_class_with_concate_attr():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support ms_class input.
Expectation: No exception.
"""
ms_obj = MSClass1()
@ms_function
def foo():
ret = 0
nums = ["0", "1", "2", "3", "4"]
for i in range(5):
attr_str = "num" + nums[i]
if hasattr(ms_obj, attr_str):
ret = ret + 1
return ret
out = foo()
assert out == 4
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.a0 = Tensor([0])
self.a1 = Tensor([1])
self.a2 = Tensor([2])
self.a3 = Tensor([3])
self.none = None
def construct(self):
return self.a0
def test_hasattr_cell_obj():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support cell object input.
Expectation: No exception.
"""
cell_obj = Net()
@ms_function
def foo():
return hasattr(cell_obj, "a0")
assert foo()
def test_hasattr_cell_obj_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support cell object input.
Expectation: No exception.
"""
cell_obj = Net()
@ms_function
def foo():
return hasattr(cell_obj, "none")
assert not foo()
def test_hasattr_cell_obj_concate_input():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support cell object input.
Expectation: No exception.
"""
cell_obj = Net()
@ms_function
def foo():
a = 0
attrs = ["0", "1", "2", "3", "4"]
for attr in attrs:
if hasattr(cell_obj, "a" + attr):
a = a + 1
return a
out = foo()
assert out == 4
def test_hasattr_numpy_array():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support numpy array input.
Expectation: TypeError
"""
@ms_function
def foo():
x = np.array([1, 2, 3, 4])
return hasattr(x, "shape")
assert foo()
def test_hasattr_numpy_array_2():
"""
Feature: Syntax hasattr.
Description: Graph syntax hasattr support numpy array input.
Expectation: TypeError
"""
@ms_function
def foo():
x = np.array([1, 2, 3, 4])
return hasattr(x, "shape2")
assert not foo()

View File

@ -27,8 +27,6 @@ context.set_context(mode=context.GRAPH_MODE)
def test_dtype_and_shape_as_attr():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
shape = x.shape
@ -62,8 +60,6 @@ def test_dtype_and_shape_as_attr_to_new_tensor():
def test_type_not_have_the_attr():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
shape = x.shapes
@ -71,14 +67,12 @@ def test_type_not_have_the_attr():
net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError):
with pytest.raises(AttributeError):
net(x)
def test_type_not_have_the_method():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
shape = x.dtypes()
@ -86,5 +80,5 @@ def test_type_not_have_the_method():
net = Net()
x = Tensor(np.ones([1, 2, 3], np.int32))
with pytest.raises(RuntimeError):
with pytest.raises(AttributeError):
net(x)