forked from mindspore-Ecosystem/mindspore
!27728 [Fallback] Support numpy augassign method, subscript method and binary operations
Merge pull request !27728 from huangbingjian/fallback_parse
commit b1de53fac2
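For context, here is a minimal sketch of the kind of graph-mode code this change targets. It simply combines the numpy augassign, subscript and binary-operation patterns exercised by the new tests added later in this diff; the imports and the assumption that GRAPH_MODE is configured are illustrative, not part of the commit.

# Illustrative only: mirrors the new JIT-fallback test cases in this commit
# (assumes mindspore is installed and graph mode is set via context.set_context).
import numpy as np
from mindspore import Tensor, ms_function

@ms_function
def np_fallback_ops():
    a = np.array([1, 2, 3])
    a += np.array([4, 5, 6])    # numpy augmented assignment
    b = a[np.int32(1)]          # numpy subscript
    c = b + np.int_(1)          # numpy binary operation
    return Tensor(c)

print(np_fallback_ops())  # expected: 8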
@@ -580,7 +580,7 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
   // Create apply node
   MS_EXCEPTION_IF_NULL(block->func_graph());
   auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
-  UpdateInterpretForUserNode(left_node, new_node);
+  UpdateInterpretForUserNode(new_node, {left_node, right_node});
   return new_node;
 }
 
@@ -735,7 +735,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
   bool need_unpack = need_unpack_args || need_unpack_keywords;
 
   auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
-  UpdateInterpretForUserNode(call_function_node, call_cnode);
+  UpdateInterpretForUserNode(call_cnode, call_function_node);
   if (call_cnode->interpret_special_type() && need_fallback) {
     call_cnode = HandleInterpret(block, call_cnode, node);
   }
@@ -883,7 +883,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
 
   // Create the apply node
   auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
-  UpdateInterpretForUserNode(value_node, attr_cnode);
+  UpdateInterpretForUserNode(attr_cnode, value_node);
   return attr_cnode;
 }
 
@@ -912,7 +912,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
   MS_EXCEPTION_IF_NULL(block);
   AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
   auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
-  UpdateInterpretForUserNode(left_node, new_node);
+  UpdateInterpretForUserNode(new_node, {left_node, right_node});
   return new_node;
 }
 
@@ -969,7 +969,7 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
 
     std::vector<AnfNodePtr> call_graph_nodes{switch_app};
     auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
-    UpdateInterpretForUserNode(test_node, switch_app_call);
+    UpdateInterpretForUserNode(switch_app_call, {test_node, rest_node});
     return switch_app_call;
   }
 }
@@ -1066,7 +1066,9 @@ AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::objec
   value = HandleInterpret(block, value, value_node);
   AnfNodePtr slice = ParseExprNode(block, slice_node);
   slice = HandleInterpret(block, slice, slice_node);
-  return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
+  auto new_node = block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
+  UpdateInterpretForUserNode(new_node, value);
+  return new_node;
 }
 
 // Process a slice, get the slice value
@@ -1120,7 +1122,7 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
   AnfNodePtr operand_node = ParseExprNode(block, operand);
   operand_node = HandleInterpret(block, operand_node, operand);
   auto new_node = block->func_graph()->NewCNodeInOrder({op_node, operand_node});
-  UpdateInterpretForUserNode(operand_node, new_node);
+  UpdateInterpretForUserNode(new_node, operand_node);
   return new_node;
 }
 
@@ -1147,6 +1149,29 @@ AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &no
   return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
 }
 
+AnfNodePtr Parser::HandleInterpretForAugassign(const FunctionBlockPtr &block, const AnfNodePtr &augassign_node,
+                                               const py::object &op_object, const py::object &target_object,
+                                               const py::object &value_object) {
+  // The fallback feature is enabled by default.
+  static const auto use_fallback = (support_fallback() != "0");
+  if (!use_fallback || !augassign_node->interpret()) {
+    return augassign_node;
+  }
+
+  std::string op_text =
+    py::cast<std::string>(ast()->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_SYMBOL, op_object));
+  // Check the symbol in the Augassign expression.
+  if (op_text.empty()) {
+    MS_LOG(EXCEPTION)
+      << "Invalid augassign operator, only support `+=`, `-=`, `*=`, `/=`, `%=`, `**=`, `//=`, `<<=`, `>>=`, `^=`.";
+  }
+
+  const auto target_text = py::cast<std::string>(ast()->GetAstNodeText(target_object));
+  const auto value_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
+  std::string script_text = target_text + op_text + value_text;
+  return MakeInterpretNode(block, augassign_node, script_text);
+}
+
 // Process an augmented assign such as a += b or mat[stride_slice] += b.
 FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
   MS_LOG(DEBUG) << "Process ast AugAssign";
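A rough Python analogue of what HandleInterpretForAugassign assembles may help: the interpret script is just the target's source text, the operator symbol, and the value's source text concatenated. The helper below is hypothetical and only restates that idea; the real logic is the C++ above (requires Python 3.8+ for ast.get_source_segment).

# Sketch: rebuild the interpret script for an augmented assignment the same way
# the parser does: target source text + operator symbol + value source text.
import ast

OPS_SYMBOL = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/'}  # subset of ops_symbol_map

def augassign_script(source):
    node = ast.parse(source).body[0]                     # an ast.AugAssign node
    target = ast.get_source_segment(source, node.target)
    value = ast.get_source_segment(source, node.value)
    return target + OPS_SYMBOL[type(node.op)] + value

print(augassign_script("value_add += np.array([4, 5, 6])"))  # value_add+np.array([4, 5, 6])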
@@ -1183,7 +1208,12 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
   if (target_node == nullptr) {
     MS_LOG(EXCEPTION) << "Can not get target node ";
   }
-  CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
+  AnfNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
+
+  // Check whether the augassign expression needs to be interpreted.
+  UpdateInterpretForUserNode(augassign_app, {target_node, value_node});
+  augassign_app = HandleInterpretForAugassign(block, augassign_app, op_object, target_object, value_object);
+
   WriteAssignVars(block, target_object, augassign_app);
   return block;
 }
@@ -1900,7 +1930,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
   }
 }
 
-void Parser::UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node) {
+void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) {
   // Do not handle user node with internal type such as Tensor.abs().
   bool interpret_without_internal = IsPrimitiveCNode(node, prim::kPrimPyInterpret) && !node->interpret_internal_type();
   if (node->interpret() || interpret_without_internal) {
@@ -1914,6 +1944,12 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr
   }
 }
 
+void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes) {
+  for (auto &node : nodes) {
+    UpdateInterpretForUserNode(user_node, node);
+  }
+}
+
 bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
                               const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
   // Check global parameters.
@@ -1945,8 +1981,12 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
   if (!use_fallback || !value_node->interpret()) {
     return value_node;
   }
-
   const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
+  return MakeInterpretNode(block, value_node, script_text);
+}
+
+AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
+                                     const string &script_text) {
   // Check if script_text is in global/local params.
   py::dict global_dict = block->global_py_params();
   auto [keys, values] = block->local_py_params();
@@ -191,10 +191,17 @@ class Parser {
   bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
                         const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
   // Set the interpret flag for the node calling the interpret node.
-  void UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node);
+  void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node);
+  void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes);
+  // Make interpret node.
+  AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);
   // Check if the node need interpreting.
   AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
                              const py::object &value_object);
+  // Handle interpret for augassign expression.
+  AnfNodePtr HandleInterpretForAugassign(const FunctionBlockPtr &block, const AnfNodePtr &augassign_node,
+                                         const py::object &op_object, const py::object &target_object,
+                                         const py::object &value_object);
 
   // Generate argument nodes for ast function node
   void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
@@ -82,6 +82,7 @@ const char PYTHON_PARSE_GET_NODE_TYPE[] = "get_node_type";
 const char PYTHON_PARSE_GET_AST_TYPE[] = "get_ast_type";
 const char PYTHON_PARSE_GET_NAMESPACE_SYMBOL[] = "get_namespace_symbol";
 const char PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL[] = "get_ast_namespace_symbol";
+const char PYTHON_PARSE_GET_OPERATION_SYMBOL[] = "get_operation_symbol";
 const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol";
 const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol";
 const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
@@ -20,9 +20,10 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
                      get_bprop_method_of_class, get_class_instance_type, get_class_member_namespace_symbol,
                      create_slice_obj, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
                      get_module_namespace, get_obj_type, get_object_key, get_ast_type, get_node_type,
-                     get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_namespace_symbol,
-                     get_parse_method_of_class, get_scope_name, eval_script, expand_expr_statement,
-                     is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description)
+                     get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol,
+                     get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script,
+                     expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor,
+                     get_object_description)
 
 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
            'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
@@ -30,4 +31,5 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
            'get_args', 'get_obj_type', 'create_instance', 'is_supported_create_instance_type',
            'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser',
            'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name',
-           'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement']
+           'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement',
+           'generate_scope', 'get_operation_symbol']
@@ -33,7 +33,7 @@ from mindspore import ops
 from mindspore.common.api import _MindsporeFunctionExecutor
 from mindspore.common.dtype import pytype_to_dtype
 from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
-from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
+from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
 
 # define return value
 RET_SUCCESS = 0
@@ -456,6 +456,13 @@ def get_ast_namespace_symbol(obj):
     return ops_info
 
 
+def get_operation_symbol(obj):
+    """Get obj operation symbol."""
+    ops_symbol = ops_symbol_map.get(type(obj), SYMBOL_UNDEFINE)
+    logger.debug("ops symbol: %s", ops_symbol)
+    return ops_symbol
+
+
 def get_operation_namespace_symbol(var: str):
     """Get operation namespace and symbol."""
     ops_info = (trope_ns, var)
@@ -685,6 +692,9 @@ class Parser:
         if name == 'mindspore.numpy':
             logger.debug(f"Found 'mindspore.numpy' namespace.")
             return True
+        if name == 'mindspore.context':
+            logger.debug(f"Found 'mindspore.context' namespace.")
+            return True
 
         # Check `builtins` namespace.
         if hasattr(value, '__module__'):  # Not types.ModuleType
@@ -75,6 +75,24 @@ parse_object_map = {
     SYMBOL_UNDEFINE: (None, 'undefine'),
 }
 
+# Operation symbols corresponding to ast grammar
+ops_symbol_map = {
+    # ast grammar
+    ast.Add: '+',
+    ast.Sub: '-',
+    ast.Mult: '*',
+    ast.Div: '/',
+    ast.FloorDiv: '//',
+    ast.Mod: '%',
+    ast.Pow: '**',
+    ast.LShift: '<<',
+    ast.RShift: '>>',
+    ast.BitXor: '^',
+
+    # undefined type
+    SYMBOL_UNDEFINE: '',
+}
+
 # Escape an object to another object, eg: system function(len,xxx)
 # Some space set aside for readability of code
 convert_object_map = {
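The expected behaviour of the new get_operation_symbol helper, restated as a standalone sketch for quick reference. SYMBOL_UNDEFINE below is a placeholder for the constant defined in resources.py, and the map is copied from the hunk above; this is not the shipped module, only an illustration of the lookup.

import ast

SYMBOL_UNDEFINE = ''  # placeholder; the real constant comes from resources.py
ops_symbol_map = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/', ast.FloorDiv: '//',
                  ast.Mod: '%', ast.Pow: '**', ast.LShift: '<<', ast.RShift: '>>', ast.BitXor: '^'}

def get_operation_symbol(obj):
    """Map an ast operator instance to its source symbol (mirrors the diff above)."""
    return ops_symbol_map.get(type(obj), SYMBOL_UNDEFINE)

print(get_operation_symbol(ast.Pow()))  # '**'
print(get_operation_symbol(ast.Eq()))   # unsupported operators fall back to SYMBOL_UNDEFINE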
@@ -201,3 +201,25 @@ def test_slice_func():
     a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
     b = Tensor([1], dtype=mstype.float32)
     print(slice_func(a, b))
+
+
+def test_context():
+    """
+    Feature: JIT Fallback
+    Description: Test context in graph.
+    Expectation: No exception.
+    """
+    class ContextNet(nn.Cell):
+        def __init__(self):
+            super(ContextNet, self).__init__()
+            self.mode = context.get_context("mode")
+
+        def construct(self):
+            out = 1
+            if self.mode == context.GRAPH_MODE:
+                out = 2
+            return out
+
+    net = ContextNet()
+    out = net()
+    print(out)
@@ -31,8 +31,8 @@ def test_np_array_1():
         a = np.array([1, 2, 3])
         return Tensor(a)
     res = np_array_1()
-    expect_res = Tensor(np.array([1, 2, 3]))
-    assert np.all(res.asnumpy() == expect_res.asnumpy())
+    expect_res = np.array([1, 2, 3])
+    assert np.all(res.asnumpy() == expect_res)
 
 
 def test_np_array_2():
@@ -46,8 +46,8 @@ def test_np_array_2():
         a = np.array([[1, 2], [3, 4]])
         return Tensor(a)
     res = np_array_2()
-    expect_res = Tensor(np.array([[1, 2], [3, 4]]))
-    assert np.all(res.asnumpy() == expect_res.asnumpy())
+    expect_res = np.array([[1, 2], [3, 4]])
+    assert np.all(res.asnumpy() == expect_res)
 
 
 def test_np_array_3():
@@ -61,8 +61,8 @@ def test_np_array_3():
         a = np.array([1, 2, 3, 4, 5], ndmin=2)
         return Tensor(a)
     res = np_array_3()
-    expect_res = Tensor(np.array([[1, 2, 3, 4, 5]]))
-    assert np.all(res.asnumpy() == expect_res.asnumpy())
+    expect_res = np.array([[1, 2, 3, 4, 5]])
+    assert np.all(res.asnumpy() == expect_res)
 
 
 def test_np_array_4():
@@ -76,7 +76,8 @@ def test_np_array_4():
         a = np.array([1, 2, 3], dtype=complex)
         return Tensor(a)
     res = np_array_4()
-    assert np.all(res.asnumpy() == Tensor(np.array([1+0j, 2+0j, 3+0j])).asnumpy())
+    expect_res = np.array([1+0j, 2+0j, 3+0j])
+    assert np.all(res.asnumpy() == expect_res)
 
 
 def test_np_dtype_1():
@@ -90,7 +91,8 @@ def test_np_dtype_1():
         t = np.dtype(np.int32)
         return Tensor(np.array([1, 2, 3], dtype=t))
     res = np_dtype_1()
-    assert np.all(res.asnumpy() == Tensor(np.array([1, 2, 3], dtype=np.int32)).asnumpy())
+    expect_res = np.array([1, 2, 3], dtype=np.int32)
+    assert np.all(res.asnumpy() == expect_res)
 
 
 def test_np_dtype_2():
@@ -104,7 +106,8 @@ def test_np_dtype_2():
         t = np.dtype('i4')
         return Tensor(np.array([1, 2, 3], dtype=t))
     res = np_dtype_2()
-    assert np.all(res.asnumpy() == Tensor(np.array([1, 2, 3], dtype=np.int32)).asnumpy())
+    expect_res = np.array([1, 2, 3], dtype=np.int32)
+    assert np.all(res.asnumpy() == expect_res)
 
 
 def test_np_array_ndim():
@@ -180,8 +183,8 @@ def test_np_empty_zeros_ones():
         z = np.ones(x.shape, dtype=np.int)
         return Tensor(y + z)
     res = np_empty_zeros_ones()
-    except_res = Tensor(np.ones([3, 2], dtype=np.int))
-    assert np.all(res.asnumpy() == except_res.asnumpy())
+    except_res = np.ones([3, 2], dtype=np.int)
+    assert np.all(res.asnumpy() == except_res)
 
 
 def test_np_asarray_list():
@@ -196,8 +199,8 @@ def test_np_asarray_list():
         y = np.asarray(x)
         return Tensor(y)
     res = np_asarray_list()
-    except_res = Tensor(np.asarray([1, 2, 3]))
-    assert np.all(res.asnumpy() == except_res.asnumpy())
+    except_res = np.asarray([1, 2, 3])
+    assert np.all(res.asnumpy() == except_res)
 
 
 def test_np_asarray_tuple():
@@ -212,8 +215,8 @@ def test_np_asarray_tuple():
         y = np.asarray(x)
         return Tensor(y)
     res = np_asarray_tuple()
-    except_res = Tensor(np.asarray((1, 2, 3)))
-    assert np.all(res.asnumpy() == except_res.asnumpy())
+    except_res = np.asarray((1, 2, 3))
+    assert np.all(res.asnumpy() == except_res)
 
 
 @pytest.mark.skip(reason='Not support graph fallback feature yet')
@@ -245,8 +248,8 @@ def test_np_fromiter():
         x = np.fromiter(it, dtype=float)
         return Tensor(x)
     res = np_fromiter()
-    except_res = Tensor(np.asarray([0., 1., 2., 3., 4.]))
-    assert np.all(res.asnumpy() == except_res.asnumpy())
+    except_res = np.asarray([0., 1., 2., 3., 4.])
+    assert np.all(res.asnumpy() == except_res)
 
 
 def test_np_arange():
@@ -261,8 +264,8 @@ def test_np_arange():
         y = np.arange(10, 20, 2)
         return Tensor(x + y)
     res = np_arange()
-    except_res = Tensor(np.asarray([10., 13., 16., 19., 22.]))
-    assert np.all(res.asnumpy() == except_res.asnumpy())
+    except_res = np.asarray([10., 13., 16., 19., 22.])
+    assert np.all(res.asnumpy() == except_res)
 
 
 def test_np_logspace():
@@ -276,8 +279,8 @@ def test_np_logspace():
         a = np.logspace(0, 9, 10, base=2)
         return Tensor(a)
     res = np_logspace()
-    except_res = Tensor(np.array([1., 2., 4., 8., 16., 32., 64., 128., 256., 512.]))
-    assert np.all(res.asnumpy() == except_res.asnumpy())
+    except_res = np.array([1., 2., 4., 8., 16., 32., 64., 128., 256., 512.])
+    assert np.all(res.asnumpy() == except_res)
 
 
 def test_np_array_shape():
@@ -352,6 +355,21 @@ def test_np_binop():
     assert np.all(res.asnumpy() == np.array([5, 7, 9]))
 
 
+def test_np_binop_2():
+    """
+    Feature: JIT Fallback
+    Description: Test numpy's binary operation in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_binop():
+        a = np.int_(1)
+        b = 4 + a
+        return Tensor(b)
+    res = np_binop()
+    assert res == 5
+
+
 def test_np_compare():
     """
     Feature: JIT Fallback
@@ -368,6 +386,22 @@ def test_np_compare():
     assert np.all(res.asnumpy() == np.array([True, False, False]))
 
 
+def test_np_compare_2():
+    """
+    Feature: JIT Fallback
+    Description: Test numpy's compare operation in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_compare():
+        a = 1
+        b = np.int_(3)
+        c = a < b
+        return Tensor(c)
+    res = np_compare()
+    assert res
+
+
 def test_np_bool_and():
     """
     Feature: JIT Fallback
@@ -381,7 +415,7 @@ def test_np_bool_and():
         c = a and b
         return Tensor(c)
     res = np_bool_and()
-    assert not res.asnumpy()
+    assert not res
 
 
 def test_np_bool_or():
@@ -397,7 +431,21 @@ def test_np_bool_or():
         c = a or b
         return Tensor(c)
     res = np_bool_or()
-    assert res.asnumpy()
+    assert res
 
 
+def test_np_bool_or_2():
+    """
+    Feature: JIT Fallback
+    Description: Test OR operation in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_bool_or():
+        out = 0 or np.bool_(True)
+        return Tensor(out)
+    res = np_bool_or()
+    assert res
+
+
 def test_np_bool_not():
@@ -412,4 +460,90 @@ def test_np_bool_not():
         b = not a
         return Tensor(b)
     res = np_bool_not()
-    assert not res.asnumpy()
+    assert not res
+
+
+def test_np_augassign():
+    """
+    Feature: JIT Fallback
+    Description: Test augassign method in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_augassign():
+        value_add = np.array([1, 2, 3])
+        value_add += np.array([4, 5, 6])
+        value_sub = np.array([5, 5, 5])
+        value_sub -= np.array([1, 2, 3])
+        value_mul = np.int_(2)
+        value_mul *= np.int_(3)
+        value_div = np.int_(10)
+        value_div /= np.int_(5)
+        value_floordiv = np.int_(5)
+        value_floordiv //= np.int_(2)
+        return Tensor(value_add), Tensor(value_sub), Tensor(value_mul), Tensor(value_div), Tensor(value_floordiv)
+
+    out_add, out_sub, out_mul, out_div, out_floordiv = np_augassign()
+    assert np.all(out_add.asnumpy() == np.array([5, 7, 9]))
+    assert np.all(out_sub.asnumpy() == np.array([4, 3, 2]))
+    assert out_mul == 6
+    assert out_div == 2
+    assert out_floordiv == 2
+
+
+def test_np_augassign_2():
+    """
+    Feature: JIT Fallback
+    Description: Test augassign method in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_augassign():
+        value_mod = np.int_(5)
+        value_mod %= np.int_(2)
+        value_pow = np.int_(3)
+        value_pow **= np.int_(2)
+        value_lshift = np.int_(4)
+        value_lshift <<= 1
+        value_rshift = np.int_(4)
+        value_rshift >>= 1
+        value_bitxor = np.int_(0)
+        value_bitxor ^= 1
+        return Tensor(value_mod), Tensor(value_pow), Tensor(value_lshift), Tensor(value_rshift), Tensor(value_bitxor)
+
+    out_mod, out_pow, out_lshift, out_rshift, out_bitxor = np_augassign()
+    assert out_mod == 1
+    assert out_pow == 9
+    assert out_lshift == 8
+    assert out_rshift == 2
+    assert out_bitxor == 1
+
+
+def test_np_subscript():
+    """
+    Feature: JIT Fallback
+    Description: Test subscript method in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_subscript():
+        a = np.array([1, 2, 3])
+        b = a[np.int32(1)]
+        return Tensor(b)
+    res = np_subscript()
+    assert res == 2
+
+
+def test_np_slice():
+    """
+    Feature: JIT Fallback
+    Description: Test slice method in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def np_slice():
+        a = np.arange(10)
+        b = a[1:5]
+        return Tensor(b)
+    res = np_slice()
+    assert np.all(res.asnumpy() == np.array([1, 2, 3, 4]))
@@ -197,6 +197,35 @@ def test_fallback_tensor_astype():
     print(foo())
 
 
+@pytest.mark.skip(reason='Not support graph fallback feature yet')
+def test_fallback_tensor_asnumpy():
+    """
+    Feature: JIT Fallback
+    Description: Test Tensor.asnumpy() in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def foo():
+        me_x = Tensor(np.arange(0, 6).reshape(2, 3))
+        np_x = me_x.asnumpy()
+        return Tensor(np_x)
+    print(foo())
+
+
+def test_fallback_tensor_from_numpy():
+    """
+    Feature: JIT Fallback
+    Description: Test Tensor.from_numpy() in graph mode.
+    Expectation: No exception.
+    """
+    @ms_function
+    def foo():
+        np_x = np.array([1, 2])
+        me_x = Tensor.from_numpy(np_x)
+        return me_x
+    print(foo())
+
+
 # EvalCNode: This may be not defined, or it can't be a operator.
 @pytest.mark.skip(reason='Not support graph fallback feature yet')
 def test_np_tensor_add():
@@ -231,9 +260,6 @@ def test_fallback_tensor_binop():
     Expectation: No exception.
     """
     class BinOpNet(nn.Cell):
-        def __init__(self):
-            super(BinOpNet, self).__init__()
-
         def construct(self):
             np_array = np.array(9)
             res = Tensor(np_array) + Tensor(np_array)
@@ -250,9 +276,6 @@ def test_fallback_tensor_compare():
     Expectation: No exception.
     """
     class CompareNet(nn.Cell):
-        def __init__(self):
-            super(CompareNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array(1)
             np_array_2 = np.array(2)
@@ -270,9 +293,6 @@ def test_fallback_tensor_not():
     Expectation: No exception.
     """
     class NotNet(nn.Cell):
-        def __init__(self):
-            super(NotNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array(True, dtype=np.bool_)
             res = not Tensor(np_array_1)
@@ -290,9 +310,6 @@ def test_fallback_tensor_and():
     Expectation: No exception.
     """
     class AndNet(nn.Cell):
-        def __init__(self):
-            super(AndNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array(True, dtype=np.bool_)
             np_array_2 = np.array(False, dtype=np.bool_)
@@ -311,9 +328,6 @@ def test_fallback_tensor_or():
     Expectation: No exception.
     """
     class OrNet(nn.Cell):
-        def __init__(self):
-            super(OrNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array(True, dtype=np.bool_)
             np_array_2 = np.array(False, dtype=np.bool_)
@@ -332,9 +346,6 @@ def test_fallback_tensor_augassign():
     Expectation: No exception.
     """
     class OrNet(nn.Cell):
-        def __init__(self):
-            super(OrNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array(1)
             np_array_2 = np.array(2)
@@ -354,9 +365,6 @@ def test_fallback_tensor_subscript():
     Expectation: No exception.
    """
     class SubScriptNet(nn.Cell):
-        def __init__(self):
-            super(SubScriptNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array([1, 2, 3, 4, 5])
             np_array_2 = np.array(2)
@@ -375,9 +383,6 @@ def test_fallback_tensor_if():
     Expectation: No exception.
     """
     class IfNet(nn.Cell):
-        def __init__(self):
-            super(IfNet, self).__init__()
-
         def construct(self):
             np_array_1 = np.array(1)
             if Tensor(np_array_1):
@@ -387,3 +392,17 @@ def test_fallback_tensor_if():
     net = IfNet()
     res = net()
     print("res:", res)
+
+
+def test_fallback_tensor_slice():
+    """
+    Feature: JIT Fallback
+    Description: support interpreted nodes in slice.
+    Expectation: No exception.
+    """
+    @ms_function
+    def foo():
+        array = np.arange(10)
+        out = Tensor(array)[1:5]
+        return out
+    print(foo())