!27728 [Fallback] Support numpy augassign method, subscript method and binary operations

Merge pull request !27728 from huangbingjian/fallback_parse
This commit is contained in:
i-robot 2021-12-22 01:01:46 +00:00 committed by Gitee
commit b1de53fac2
9 changed files with 317 additions and 64 deletions

View File

@ -580,7 +580,7 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
// Create apply node
MS_EXCEPTION_IF_NULL(block->func_graph());
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
UpdateInterpretForUserNode(left_node, new_node);
UpdateInterpretForUserNode(new_node, {left_node, right_node});
return new_node;
}
@ -735,7 +735,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
bool need_unpack = need_unpack_args || need_unpack_keywords;
auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
UpdateInterpretForUserNode(call_function_node, call_cnode);
UpdateInterpretForUserNode(call_cnode, call_function_node);
if (call_cnode->interpret_special_type() && need_fallback) {
call_cnode = HandleInterpret(block, call_cnode, node);
}
@ -883,7 +883,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
// Create the apply node
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
UpdateInterpretForUserNode(value_node, attr_cnode);
UpdateInterpretForUserNode(attr_cnode, value_node);
return attr_cnode;
}
@ -912,7 +912,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
UpdateInterpretForUserNode(left_node, new_node);
UpdateInterpretForUserNode(new_node, {left_node, right_node});
return new_node;
}
@ -969,7 +969,7 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
std::vector<AnfNodePtr> call_graph_nodes{switch_app};
auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
UpdateInterpretForUserNode(test_node, switch_app_call);
UpdateInterpretForUserNode(switch_app_call, {test_node, rest_node});
return switch_app_call;
}
}
@ -1066,7 +1066,9 @@ AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::objec
value = HandleInterpret(block, value, value_node);
AnfNodePtr slice = ParseExprNode(block, slice_node);
slice = HandleInterpret(block, slice, slice_node);
return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
auto new_node = block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
UpdateInterpretForUserNode(new_node, value);
return new_node;
}
// Process a slice, get the slice value
@ -1120,7 +1122,7 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
AnfNodePtr operand_node = ParseExprNode(block, operand);
operand_node = HandleInterpret(block, operand_node, operand);
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, operand_node});
UpdateInterpretForUserNode(operand_node, new_node);
UpdateInterpretForUserNode(new_node, operand_node);
return new_node;
}
@ -1147,6 +1149,29 @@ AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &no
return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
}
AnfNodePtr Parser::HandleInterpretForAugassign(const FunctionBlockPtr &block, const AnfNodePtr &augassign_node,
const py::object &op_object, const py::object &target_object,
const py::object &value_object) {
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (!use_fallback || !augassign_node->interpret()) {
return augassign_node;
}
std::string op_text =
py::cast<std::string>(ast()->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_SYMBOL, op_object));
// Check the symbol in the Augasssign expression.
if (op_text.empty()) {
MS_LOG(EXCEPTION)
<< "Invalid augasssign operator, only support `+=`, `-=`, `*=`, `/=`, `%=`, `**=`, `//=`, `<<=`, `>>=`, `^=`.";
}
const auto target_text = py::cast<std::string>(ast()->GetAstNodeText(target_object));
const auto value_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
std::string script_text = target_text + op_text + value_text;
return MakeInterpretNode(block, augassign_node, script_text);
}
// Process a augment assign such as a += b or mat[stride_slice] += b.
FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast AugAssign";
@ -1183,7 +1208,12 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
if (target_node == nullptr) {
MS_LOG(EXCEPTION) << "Can not get target node ";
}
CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
AnfNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
// Check whether the augassign expression needs to be interpreted.
UpdateInterpretForUserNode(augassign_app, {target_node, value_node});
augassign_app = HandleInterpretForAugassign(block, augassign_app, op_object, target_object, value_object);
WriteAssignVars(block, target_object, augassign_app);
return block;
}
@ -1900,7 +1930,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
}
}
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node) {
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) {
// Do not handle user node with internal type such as Tensor.abs().
bool interpret_without_internal = IsPrimitiveCNode(node, prim::kPrimPyInterpret) && !node->interpret_internal_type();
if (node->interpret() || interpret_without_internal) {
@ -1914,6 +1944,12 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr
}
}
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes) {
for (auto &node : nodes) {
UpdateInterpretForUserNode(user_node, node);
}
}
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
// Check global parameters.
@ -1945,8 +1981,12 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
if (!use_fallback || !value_node->interpret()) {
return value_node;
}
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
return MakeInterpretNode(block, value_node, script_text);
}
AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const string &script_text) {
// Check if script_text is in global/local params.
py::dict global_dict = block->global_py_params();
auto [keys, values] = block->local_py_params();

View File

@ -191,10 +191,17 @@ class Parser {
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
// Set the interpret flag for the node calling the interpret node.
void UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node);
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node);
void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
// Make interpret node.
AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);
// Check if the node need interpreting.
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object);
// Handle interpret for augassign expression.
AnfNodePtr HandleInterpretForAugassign(const FunctionBlockPtr &block, const AnfNodePtr &augassign_node,
const py::object &op_object, const py::object &target_object,
const py::object &value_object);
// Generate argument nodes for ast function node
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);

View File

@ -82,6 +82,7 @@ const char PYTHON_PARSE_GET_NODE_TYPE[] = "get_node_type";
const char PYTHON_PARSE_GET_AST_TYPE[] = "get_ast_type";
const char PYTHON_PARSE_GET_NAMESPACE_SYMBOL[] = "get_namespace_symbol";
const char PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL[] = "get_ast_namespace_symbol";
const char PYTHON_PARSE_GET_OPERATION_SYMBOL[] = "get_operation_symbol";
const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol";
const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol";
const char PYTHON_PARSE_GET_LOCATION[] = "get_location";

View File

@ -20,9 +20,10 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_namespace_symbol,
get_parse_method_of_class, get_scope_name, eval_script, expand_expr_statement,
is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
get_object_description)
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
@ -30,4 +31,5 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type',
'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']
'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
'generate_scope', 'get_operation_symbol']

View File

@ -33,7 +33,7 @@ from mindspore import ops
from mindspore.common.api import _MindsporeFunctionExecutor
from mindspore.common.dtype import pytype_to_dtype
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
# define return value
RET_SUCCESS = 0
@ -456,6 +456,13 @@ def get_ast_namespace_symbol(obj):
return ops_info
def get_operation_symbol(obj):
"""Get obj operation symbol."""
ops_symbol = ops_symbol_map.get(type(obj), SYMBOL_UNDEFINE)
logger.debug("ops symbol: %s", ops_symbol)
return ops_symbol
def get_operation_namespace_symbol(var: str):
"""Get operation namespace and symbol."""
ops_info = (trope_ns, var)
@ -685,6 +692,9 @@ class Parser:
if name == 'mindspore.numpy':
logger.debug(f"Found 'mindspore.numpy' namespace.")
return True
if name == 'mindspore.context':
logger.debug(f"Found 'mindspore.context' namespace.")
return True
# Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType

View File

@ -75,6 +75,24 @@ parse_object_map = {
SYMBOL_UNDEFINE: (None, 'undefine'),
}
# Operation symbols corresponding to ast grammar
ops_symbol_map = {
# ast grammar
ast.Add: '+',
ast.Sub: '-',
ast.Mult: '*',
ast.Div: '/',
ast.FloorDiv: '//',
ast.Mod: '%',
ast.Pow: '**',
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
# undefined type
SYMBOL_UNDEFINE: '',
}
# Escape an object to another object, eg: system function(len,xxx)
# Some space set aside for readability of code
convert_object_map = {

View File

@ -201,3 +201,25 @@ def test_slice_func():
a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
b = Tensor([1], dtype=mstype.float32)
print(slice_func(a, b))
def test_context():
"""
Feature: JIT Fallback
Description: Test context in graph.
Expectation: No exception.
"""
class ContextNet(nn.Cell):
def __init__(self):
super(ContextNet, self).__init__()
self.mode = context.get_context("mode")
def construct(self):
out = 1
if self.mode == context.GRAPH_MODE:
out = 2
return out
net = ContextNet()
out = net()
print(out)

View File

@ -31,8 +31,8 @@ def test_np_array_1():
a = np.array([1, 2, 3])
return Tensor(a)
res = np_array_1()
expect_res = Tensor(np.array([1, 2, 3]))
assert np.all(res.asnumpy() == expect_res.asnumpy())
expect_res = np.array([1, 2, 3])
assert np.all(res.asnumpy() == expect_res)
def test_np_array_2():
@ -46,8 +46,8 @@ def test_np_array_2():
a = np.array([[1, 2], [3, 4]])
return Tensor(a)
res = np_array_2()
expect_res = Tensor(np.array([[1, 2], [3, 4]]))
assert np.all(res.asnumpy() == expect_res.asnumpy())
expect_res = np.array([[1, 2], [3, 4]])
assert np.all(res.asnumpy() == expect_res)
def test_np_array_3():
@ -61,8 +61,8 @@ def test_np_array_3():
a = np.array([1, 2, 3, 4, 5], ndmin=2)
return Tensor(a)
res = np_array_3()
expect_res = Tensor(np.array([[1, 2, 3, 4, 5]]))
assert np.all(res.asnumpy() == expect_res.asnumpy())
expect_res = np.array([[1, 2, 3, 4, 5]])
assert np.all(res.asnumpy() == expect_res)
def test_np_array_4():
@ -76,7 +76,8 @@ def test_np_array_4():
a = np.array([1, 2, 3], dtype=complex)
return Tensor(a)
res = np_array_4()
assert np.all(res.asnumpy() == Tensor(np.array([1+0j, 2+0j, 3+0j])).asnumpy())
expect_res = np.array([1+0j, 2+0j, 3+0j])
assert np.all(res.asnumpy() == expect_res)
def test_np_dtype_1():
@ -90,7 +91,8 @@ def test_np_dtype_1():
t = np.dtype(np.int32)
return Tensor(np.array([1, 2, 3], dtype=t))
res = np_dtype_1()
assert np.all(res.asnumpy() == Tensor(np.array([1, 2, 3], dtype=np.int32)).asnumpy())
expect_res = np.array([1, 2, 3], dtype=np.int32)
assert np.all(res.asnumpy() == expect_res)
def test_np_dtype_2():
@ -104,7 +106,8 @@ def test_np_dtype_2():
t = np.dtype('i4')
return Tensor(np.array([1, 2, 3], dtype=t))
res = np_dtype_2()
assert np.all(res.asnumpy() == Tensor(np.array([1, 2, 3], dtype=np.int32)).asnumpy())
expect_res = np.array([1, 2, 3], dtype=np.int32)
assert np.all(res.asnumpy() == expect_res)
def test_np_array_ndim():
@ -180,8 +183,8 @@ def test_np_empty_zeros_ones():
z = np.ones(x.shape, dtype=np.int)
return Tensor(y + z)
res = np_empty_zeros_ones()
except_res = Tensor(np.ones([3, 2], dtype=np.int))
assert np.all(res.asnumpy() == except_res.asnumpy())
except_res = np.ones([3, 2], dtype=np.int)
assert np.all(res.asnumpy() == except_res)
def test_np_asarray_list():
@ -196,8 +199,8 @@ def test_np_asarray_list():
y = np.asarray(x)
return Tensor(y)
res = np_asarray_list()
except_res = Tensor(np.asarray([1, 2, 3]))
assert np.all(res.asnumpy() == except_res.asnumpy())
except_res = np.asarray([1, 2, 3])
assert np.all(res.asnumpy() == except_res)
def test_np_asarray_tuple():
@ -212,8 +215,8 @@ def test_np_asarray_tuple():
y = np.asarray(x)
return Tensor(y)
res = np_asarray_tuple()
except_res = Tensor(np.asarray((1, 2, 3)))
assert np.all(res.asnumpy() == except_res.asnumpy())
except_res = np.asarray((1, 2, 3))
assert np.all(res.asnumpy() == except_res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@ -245,8 +248,8 @@ def test_np_fromiter():
x = np.fromiter(it, dtype=float)
return Tensor(x)
res = np_fromiter()
except_res = Tensor(np.asarray([0., 1., 2., 3., 4.]))
assert np.all(res.asnumpy() == except_res.asnumpy())
except_res = np.asarray([0., 1., 2., 3., 4.])
assert np.all(res.asnumpy() == except_res)
def test_np_arange():
@ -261,8 +264,8 @@ def test_np_arange():
y = np.arange(10, 20, 2)
return Tensor(x + y)
res = np_arange()
except_res = Tensor(np.asarray([10., 13., 16., 19., 22.]))
assert np.all(res.asnumpy() == except_res.asnumpy())
except_res = np.asarray([10., 13., 16., 19., 22.])
assert np.all(res.asnumpy() == except_res)
def test_np_logspace():
@ -276,8 +279,8 @@ def test_np_logspace():
a = np.logspace(0, 9, 10, base=2)
return Tensor(a)
res = np_logspace()
except_res = Tensor(np.array([1., 2., 4., 8., 16., 32., 64., 128., 256., 512.]))
assert np.all(res.asnumpy() == except_res.asnumpy())
except_res = np.array([1., 2., 4., 8., 16., 32., 64., 128., 256., 512.])
assert np.all(res.asnumpy() == except_res)
def test_np_array_shape():
@ -352,6 +355,21 @@ def test_np_binop():
assert np.all(res.asnumpy() == np.array([5, 7, 9]))
def test_np_binop_2():
"""
Feature: JIT Fallback
Description: Test numpy's binary operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_binop():
a = np.int_(1)
b = 4 + a
return Tensor(b)
res = np_binop()
assert res == 5
def test_np_compare():
"""
Feature: JIT Fallback
@ -368,6 +386,22 @@ def test_np_compare():
assert np.all(res.asnumpy() == np.array([True, False, False]))
def test_np_compare_2():
"""
Feature: JIT Fallback
Description: Test numpy's compare operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_compare():
a = 1
b = np.int_(3)
c = a < b
return Tensor(c)
res = np_compare()
assert res
def test_np_bool_and():
"""
Feature: JIT Fallback
@ -381,7 +415,7 @@ def test_np_bool_and():
c = a and b
return Tensor(c)
res = np_bool_and()
assert not res.asnumpy()
assert not res
def test_np_bool_or():
@ -397,7 +431,21 @@ def test_np_bool_or():
c = a or b
return Tensor(c)
res = np_bool_or()
assert res.asnumpy()
assert res
def test_np_bool_or_2():
"""
Feature: JIT Fallback
Description: Test OR operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_bool_or():
out = 0 or np.bool_(True)
return Tensor(out)
res = np_bool_or()
assert res
def test_np_bool_not():
@ -412,4 +460,90 @@ def test_np_bool_not():
b = not a
return Tensor(b)
res = np_bool_not()
assert not res.asnumpy()
assert not res
def test_np_augassign():
"""
Feature: JIT Fallback
Description: Test augassign method in graph mode.
Expectation: No exception.
"""
@ms_function
def np_augassign():
value_add = np.array([1, 2, 3])
value_add += np.array([4, 5, 6])
value_sub = np.array([5, 5, 5])
value_sub -= np.array([1, 2, 3])
value_mul = np.int_(2)
value_mul *= np.int_(3)
value_div = np.int_(10)
value_div /= np.int_(5)
value_floordiv = np.int_(5)
value_floordiv //= np.int_(2)
return Tensor(value_add), Tensor(value_sub), Tensor(value_mul), Tensor(value_div), Tensor(value_floordiv)
out_add, out_sub, out_mul, out_div, out_floordiv = np_augassign()
assert np.all(out_add.asnumpy() == np.array([5, 7, 9]))
assert np.all(out_sub.asnumpy() == np.array([4, 3, 2]))
assert out_mul == 6
assert out_div == 2
assert out_floordiv == 2
def test_np_augassign_2():
"""
Feature: JIT Fallback
Description: Test augassign method in graph mode.
Expectation: No exception.
"""
@ms_function
def np_augassign():
value_mod = np.int_(5)
value_mod %= np.int_(2)
value_pow = np.int_(3)
value_pow **= np.int_(2)
value_lshift = np.int_(4)
value_lshift <<= 1
value_rshift = np.int_(4)
value_rshift >>= 1
value_bitxor = np.int_(0)
value_bitxor ^= 1
return Tensor(value_mod), Tensor(value_pow), Tensor(value_lshift), Tensor(value_rshift), Tensor(value_bitxor)
out_mod, out_pow, out_lshift, out_rshift, out_bitxor = np_augassign()
assert out_mod == 1
assert out_pow == 9
assert out_lshift == 8
assert out_rshift == 2
assert out_bitxor == 1
def test_np_subscript():
"""
Feature: JIT Fallback
Description: Test subscript method in graph mode.
Expectation: No exception.
"""
@ms_function
def np_subscript():
a = np.array([1, 2, 3])
b = a[np.int32(1)]
return Tensor(b)
res = np_subscript()
assert res == 2
def test_np_slice():
"""
Feature: JIT Fallback
Description: Test slice method in graph mode.
Expectation: No exception.
"""
@ms_function
def np_slice():
a = np.arange(10)
b = a[1:5]
return Tensor(b)
res = np_slice()
assert np.all(res.asnumpy() == np.array([1, 2, 3, 4]))

View File

@ -197,6 +197,35 @@ def test_fallback_tensor_astype():
print(foo())
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_asnumpy():
"""
Feature: JIT Fallback
Description: Test Tensor.asnumpy() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
me_x = Tensor(np.arange(0, 6).reshape(2, 3))
np_x = me_x.asnumpy()
return Tensor(np_x)
print(foo())
def test_fallback_tensor_from_numpy():
"""
Feature: JIT Fallback
Description: Test Tensor.from_numpy() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
np_x = np.array([1, 2])
me_x = Tensor.from_numpy(np_x)
return me_x
print(foo())
# EvalCNode: This may be not defined, or it can't be a operator.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_add():
@ -231,9 +260,6 @@ def test_fallback_tensor_binop():
Expectation: No exception.
"""
class BinOpNet(nn.Cell):
def __init__(self):
super(BinOpNet, self).__init__()
def construct(self):
np_array = np.array(9)
res = Tensor(np_array) + Tensor(np_array)
@ -250,9 +276,6 @@ def test_fallback_tensor_compare():
Expectation: No exception.
"""
class CompareNet(nn.Cell):
def __init__(self):
super(CompareNet, self).__init__()
def construct(self):
np_array_1 = np.array(1)
np_array_2 = np.array(2)
@ -270,9 +293,6 @@ def test_fallback_tensor_not():
Expectation: No exception.
"""
class NotNet(nn.Cell):
def __init__(self):
super(NotNet, self).__init__()
def construct(self):
np_array_1 = np.array(True, dtype=np.bool_)
res = not Tensor(np_array_1)
@ -290,9 +310,6 @@ def test_fallback_tensor_and():
Expectation: No exception.
"""
class AndNet(nn.Cell):
def __init__(self):
super(AndNet, self).__init__()
def construct(self):
np_array_1 = np.array(True, dtype=np.bool_)
np_array_2 = np.array(False, dtype=np.bool_)
@ -311,9 +328,6 @@ def test_fallback_tensor_or():
Expectation: No exception.
"""
class OrNet(nn.Cell):
def __init__(self):
super(OrNet, self).__init__()
def construct(self):
np_array_1 = np.array(True, dtype=np.bool_)
np_array_2 = np.array(False, dtype=np.bool_)
@ -332,9 +346,6 @@ def test_fallback_tensor_augassign():
Expectation: No exception.
"""
class OrNet(nn.Cell):
def __init__(self):
super(OrNet, self).__init__()
def construct(self):
np_array_1 = np.array(1)
np_array_2 = np.array(2)
@ -354,9 +365,6 @@ def test_fallback_tensor_subscript():
Expectation: No exception.
"""
class SubScriptNet(nn.Cell):
def __init__(self):
super(SubScriptNet, self).__init__()
def construct(self):
np_array_1 = np.array([1, 2, 3, 4, 5])
np_array_2 = np.array(2)
@ -375,9 +383,6 @@ def test_fallback_tensor_if():
Expectation: No exception.
"""
class IfNet(nn.Cell):
def __init__(self):
super(IfNet, self).__init__()
def construct(self):
np_array_1 = np.array(1)
if Tensor(np_array_1):
@ -387,3 +392,17 @@ def test_fallback_tensor_if():
net = IfNet()
res = net()
print("res:", res)
def test_fallback_tensor_slice():
"""
Feature: JIT Fallback
Description: support interpreted nodes in slice.
Expectation: No exception.
"""
@ms_function
def foo():
array = np.arange(10)
out = Tensor(array)[1:5]
return out
print(foo())