fallback parse Tensor, different from NumPy.
parent 21176d6c51
commit db4567989a
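
In short, this change makes JIT fallback classify `Tensor` as an unsupported *internal* type: when the Python resolver meets a `Tensor` symbol it now returns a SYNTAX_UNSUPPORTED_INTERNAL_TYPE flag, and the resulting interpret node is marked `interpret_internal_type`, so Tensor expressions are handled differently from NumPy (external) values. A minimal usage sketch of the behaviour exercised by the tests added below (assuming a MindSpore build containing this commit; imports follow the existing test files):

import numpy as np
from mindspore import Tensor, ms_function, context

context.set_context(mode=context.GRAPH_MODE)

@ms_function
def tensor_astype():
    # Tensor(...).astype() is parsed through the JIT fallback path and
    # tagged as an internal type, unlike plain NumPy values.
    return Tensor([1.1, -2.1]).astype("float32")

print(tensor_astype())
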
@@ -69,6 +69,12 @@ AST_SUB_TYPE_STARRED = 8 # ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 9 # ast.Attribute
AST_SUB_TYPE_UNKNOWN = 0xFF # unknown

# Syntax support
SYNTAX_SUPPORTED = 0 # supported syntax
SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1 # unsupported internal type
SYNTAX_UNSUPPORTED_EXTERNAL_TYPE = 2 # unsupported external type
SYNTAX_UNSUPPORTED_NAMESPACE = 3 # unsupported namespace

# Process expr statement white list
# add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
parse_expr_statement_white_list = (
@@ -77,6 +83,14 @@ parse_expr_statement_white_list = (

_builtin_function_or_method_type = type(abs)

_unsupported_python_builtin_type = (
    list, tuple, set, dict, slice, bool, int, float, str, complex, reversed,
)

_unsupported_internal_type = (
    Tensor,
)


def create_slice_obj(start, end, step):
    """Create slice object"""
@@ -613,6 +627,20 @@ class Parser:
        logger.debug(f"'{value}' unsupported: {unsupported}.")
        return unsupported

    def is_unsupported_python_builtin_type(self, value):
        """To check if not supported builtin type"""
        unsupported = value in _unsupported_python_builtin_type
        logger.debug(f"value: '{value}', unsupported builtin type: {unsupported}.")
        return unsupported

    def is_unsupported_internal_type(self, value):
        """To check if not supported internal type, such as Tensor"""
        for item in _unsupported_internal_type:
            if value == item:
                logger.debug(f"Found unsupported internal type: '{value}'.")
                return True
        return False

    def get_namespace_symbol(self, var: str):
        """Get symbol type and namespace and symbol."""
        if var in self.closure_namespace:
@@ -629,13 +657,6 @@ class Parser:
            error_info = f"The name '{var}' is not defined in function '{self.function_name}'."
            return None, error_info

    def is_unsupported_builtin_type(self, value_type):
        """To check if not supported builtin type"""
        unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str, complex, reversed)
        is_unsupported = value_type in unsupported_builtin_type
        logger.debug(f"value_type: {value_type}, unsupported builtin type: {is_unsupported}.")
        return is_unsupported

    def is_supported_namespace_module(self, value):
        """To check if the module is allowed to support."""
        # Check `mindspore` namespace.
@@ -656,11 +677,6 @@ class Parser:
            logger.debug(f"Found 'mindspore.numpy' namespace.")
            return True

        # Check `Tensor` namespace.
        if value == Tensor:
            logger.debug(f"Not support '{name}'.")
            return False

        # Check `builtins` namespace.
        if hasattr(value, '__module__'): # Not types.ModuleType
            mod = value.__module__
@@ -713,14 +729,13 @@ class Parser:
            value_str = value.__name__ if hasattr(value, '__name__') else str(value)
            logger.debug(f"value: {type(value)}, '{value_str}', hasattr(__name__): {hasattr(value, '__name__')}.")
            # To check if allowed to support.
            if self.is_unsupported_namespace(value):
                return self.global_namespace, var, value
            if self.is_unsupported_builtin_type(value):
                return self.global_namespace, var, value
            if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
                return self.global_namespace, var, value
            supported = True
            return self.global_namespace, var, value, supported
            if self.is_unsupported_internal_type(value):
                return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_INTERNAL_TYPE
            if self.is_unsupported_python_builtin_type(value):
                return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
            if self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value):
                return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE
            return self.global_namespace, var, value, SYNTAX_SUPPORTED

        error_info = f"The name '{var}' is not defined, or not supported in graph mode."
        logger.debug(f"error_info: {error_info}")
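
The Python-side checks above amount to a small classification routine: a resolved global value is first tested against the internal-type tuple (currently only `Tensor`), then against the unsupported Python builtin types, then against the namespace/module rules, and `get_namespace_symbol` now returns a 4-tuple whose last element is one of the `SYNTAX_*` flags. A self-contained sketch of that ordering (plain Python; `Tensor` here is a stand-in class, and the real resolver also consults namespaces):

SYNTAX_SUPPORTED = 0
SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1
SYNTAX_UNSUPPORTED_EXTERNAL_TYPE = 2
SYNTAX_UNSUPPORTED_NAMESPACE = 3

class Tensor:  # stand-in for mindspore.Tensor
    pass

_unsupported_internal_type = (Tensor,)
_unsupported_python_builtin_type = (
    list, tuple, set, dict, slice, bool, int, float, str, complex, reversed,
)

def classify(value):
    """Mimic the order of checks in get_namespace_symbol."""
    if any(value == item for item in _unsupported_internal_type):
        return SYNTAX_UNSUPPORTED_INTERNAL_TYPE
    if value in _unsupported_python_builtin_type:
        return SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
    # The real parser additionally rejects unsupported namespaces/modules here.
    return SYNTAX_SUPPORTED

assert classify(Tensor) == SYNTAX_UNSUPPORTED_INTERNAL_TYPE
assert classify(list) == SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
assert classify(abs) == SYNTAX_SUPPORTED
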
@@ -220,13 +220,13 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &info) {

AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
  constexpr size_t closure_info_size = 2;
  constexpr size_t unsupported_info_size = 3;
  constexpr size_t supported_info_size = 4;
  constexpr size_t namespace_info_size = 4;
  constexpr size_t namespace_index = 0;
  constexpr size_t symbol_index = 1;
  constexpr size_t value_index = 2;
  if (info.size() < closure_info_size || info.size() > supported_info_size) {
    MS_EXCEPTION(NameError) << "namespace info size should be 2, 3 or 4, but got " << info.size();
  constexpr size_t flag_index = 3;
  if (info.size() != closure_info_size && info.size() != namespace_info_size) {
    MS_EXCEPTION(NameError) << "namespace info size should be 2 or 4, but got " << info.size();
  }

  // Handle closure namespace info.
@@ -240,8 +240,12 @@ AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {

  // Handle global namespace info.
  auto resolved_node = GetResolveNode(info);
  if (info.size() == unsupported_info_size) {
  auto syntax_support = info[flag_index].cast<int32_t>();
  if (syntax_support != SYNTAX_SUPPORTED) {
    resolved_node->set_interpret(true);
    if (syntax_support == SYNTAX_UNSUPPORTED_INTERNAL_TYPE) {
      resolved_node->set_interpret_internal_type(true);
    }
  }
  SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
  py::object py_obj = info[value_index];
@@ -309,6 +313,7 @@ AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const An
  auto node = func_graph_->NewCNodeInOrder(
    {NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
  node->set_interpreted_node(orig_node);
  node->set_interpret_internal_type(orig_node->interpret_internal_type());
  return node;
}
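
On the C++ side, `HandleBuiltinNamespaceInfo` now expects the namespace info tuple produced by the Python resolver to have either 2 elements (closure namespace) or 4 elements (namespace, symbol, value, syntax-support flag), and the flag at index 3 drives the interpret marks on the resolved node. A rough Python rendering of that branch, with a plain dict standing in for the AnfNode flags (names here are illustrative, not the C++ API):

SYNTAX_SUPPORTED = 0
SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1

CLOSURE_INFO_SIZE = 2
NAMESPACE_INFO_SIZE = 4
FLAG_INDEX = 3

def handle_namespace_info(info):
    """Sketch of the size check and flag handling done in C++."""
    if len(info) not in (CLOSURE_INFO_SIZE, NAMESPACE_INFO_SIZE):
        raise NameError(f"namespace info size should be 2 or 4, but got {len(info)}")
    node = {"interpret": False, "interpret_internal_type": False}
    if len(info) == NAMESPACE_INFO_SIZE:
        syntax_support = info[FLAG_INDEX]
        if syntax_support != SYNTAX_SUPPORTED:
            node["interpret"] = True
            if syntax_support == SYNTAX_UNSUPPORTED_INTERNAL_TYPE:
                node["interpret_internal_type"] = True
    return node

# A Tensor symbol resolved by the Python parser arrives with flag 1:
print(handle_namespace_info(("global_ns", "Tensor", object(), 1)))
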
@@ -560,6 +560,7 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast BinOP";

  MS_EXCEPTION_IF_NULL(block);
  py::object left = python_adapter::GetPyObjAttr(node, "left");
  py::object right = python_adapter::GetPyObjAttr(node, "right");
  py::object op = python_adapter::GetPyObjAttr(node, "op");
@@ -575,11 +576,12 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
  }
  right_node = HandleInterpret(block, right_node, right);
  // Resolve the op
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_node = block->MakeResolveAstOp(op);
  // Create apply node
  MS_EXCEPTION_IF_NULL(block->func_graph());
  return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  UpdateInterpretForUserNode(left_node, new_node);
  return new_node;
}

AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
@@ -733,9 +735,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
  bool need_unpack = need_unpack_args || need_unpack_keywords;

  auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
  if (call_function_node->interpret()) {
    call_cnode->set_interpret(true);
  }
  UpdateInterpretForUserNode(call_function_node, call_cnode);
  return call_cnode;
}

@@ -878,9 +878,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec

  // Create the apply node
  auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
  if (value_node->interpret() || IsPrimitiveCNode(value_node, prim::kPrimPyInterpret)) {
    attr_cnode->set_interpret(true);
  }
  UpdateInterpretForUserNode(value_node, attr_cnode);
  return attr_cnode;
}

@@ -908,7 +906,9 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object

  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
  return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
  UpdateInterpretForUserNode(left_node, new_node);
  return new_node;
}

AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
@@ -964,6 +964,7 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p

    std::vector<AnfNodePtr> call_graph_nodes{switch_app};
    auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
    UpdateInterpretForUserNode(test_node, switch_app_call);
    return switch_app_call;
  }
}
@@ -1101,7 +1102,7 @@ AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &n
  return ParseExprNode(block, value_node);
}

// Process a UnaryOp, +a, -b
// Process a UnaryOp, +a, -b
AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast UnaryOp";
  py::object op = python_adapter::GetPyObjAttr(node, "op");
@@ -1113,7 +1114,9 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
  py::object operand = python_adapter::GetPyObjAttr(node, "operand");
  AnfNodePtr operand_node = ParseExprNode(block, operand);
  operand_node = HandleInterpret(block, operand_node, operand);
  return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
  auto new_node = block->func_graph()->NewCNodeInOrder({op_node, operand_node});
  UpdateInterpretForUserNode(operand_node, new_node);
  return new_node;
}

// Process a dict ast node expression
@@ -1179,6 +1182,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
  WriteAssignVars(block, target_object, augassign_app);
  return block;
}

// Process global declaration such as 'global x';
FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Global";
@@ -1891,6 +1895,17 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
  }
}

void Parser::UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node) {
  // Do not handle user node with internal type such as Tensor.abs().
  bool interpret_without_internal = IsPrimitiveCNode(node, prim::kPrimPyInterpret) && !node->interpret_internal_type();
  if (node->interpret() || interpret_without_internal) {
    user_node->set_interpret(true);
    if (node->interpret_internal_type()) {
      user_node->set_interpret_internal_type(true);
    }
  }
}

bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
                              const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
  // Check global parameters.
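
`UpdateInterpretForUserNode` centralizes the flag propagation that was previously written inline in `ParseCall` and `ParseAttribute` and is now also applied to the BinOp, Compare, UnaryOp and bool-op results above: a node that uses an interpreted operand inherits the interpret mark, and the internal-type mark travels with it. A condensed Python sketch of that rule (dicts stand in for AnfNode flags; `is_py_interpret` stands in for the kPrimPyInterpret check):

def update_interpret_for_user_node(node, user_node, is_py_interpret=False):
    """Propagate interpret flags from an operand to the node using it."""
    interpret_without_internal = is_py_interpret and not node["interpret_internal_type"]
    if node["interpret"] or interpret_without_internal:
        user_node["interpret"] = True
        if node["interpret_internal_type"]:
            user_node["interpret_internal_type"] = True

operand = {"interpret": True, "interpret_internal_type": True}  # e.g. a Tensor.abs() result
binop = {"interpret": False, "interpret_internal_type": False}
update_interpret_for_user_node(operand, binop)
print(binop)  # both flags propagate to the user node
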
@@ -190,13 +190,15 @@ class Parser {
  // Check if script_text is in global/local params.
  bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
                        const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
  // Set the interpret flag for the node calling the interpret node.
  void UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node);
  // Check if the node need interpreting.
  AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
                             const py::object &value_object);

  // Generate argument nodes for ast function node
  // Generate argument nodes for ast function node
  void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
  // Generate argument default value for ast function node
  // Generate argument default value for ast function node
  void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
  // Parse ast function node
  FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
@@ -166,6 +166,14 @@ enum ClassInstanceTypeDef {
  CLASS_INSTANCE_TYPE_INVALID = 0xFF
};

// Define syntax support.
enum SyntaxSupportDef {
  SYNTAX_SUPPORTED = 0, // supported syntax
  SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1, // unsupported internal type
  SYNTAX_UNSUPPORTED_EXTERNAL_TYPE = 2, // unsupported external type
  SYNTAX_UNSUPPORTED_NAMESPACE = 3 // unsupported namespace
};

// Convert python object to ValuePtr.
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr);

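
The C++ enum above mirrors the module-level constants introduced in the Python parser earlier in this diff; the two definitions have to stay in sync because the Python side passes the raw integer through the namespace info tuple and the C++ side casts it with `cast<int32_t>()`. A trivial consistency sketch (values copied from this diff, names illustrative):

PY_FLAGS = {"SUPPORTED": 0, "UNSUPPORTED_INTERNAL_TYPE": 1,
            "UNSUPPORTED_EXTERNAL_TYPE": 2, "UNSUPPORTED_NAMESPACE": 3}
CPP_FLAGS = {"SUPPORTED": 0, "UNSUPPORTED_INTERNAL_TYPE": 1,
             "UNSUPPORTED_EXTERNAL_TYPE": 2, "UNSUPPORTED_NAMESPACE": 3}
assert PY_FLAGS == CPP_FLAGS
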
@@ -119,6 +119,7 @@ class MS_CORE_API AnfNode : public Base {
        scope_(ScopeManager::GetInstance().GetCurrentScope()),
        kernel_info_(nullptr),
        interpret_(false),
        interpret_internal_type_(false),
        interpreted_node_(nullptr) {}

  /// \brief Constructor.
@@ -345,6 +346,18 @@ class MS_CORE_API AnfNode : public Base {
  /// \param[in] interpret Boolean.
  void set_interpret(const bool &interpret) { interpret_ = interpret; }

  /// \brief Check if there is an interpret node related to the unsupported internal type.
  ///
  /// \return True if there is an interpret node related to the unsupported internal type, otherwise false.
  bool interpret_internal_type() { return interpret_internal_type_; }

  /// \brief Whether there is an interpret node with unsupported internal type.
  ///
  /// \param[in] interpret_internal_type Boolean.
  void set_interpret_internal_type(const bool &interpret_internal_type) {
    interpret_internal_type_ = interpret_internal_type;
  }

  /// \brief Get interpreted node.
  ///
  /// \return Interpreted node.
@@ -369,6 +382,7 @@ class MS_CORE_API AnfNode : public Base {
  KernelInfoDevicePtr kernel_info_;
  UserData user_data_;
  bool interpret_;
  bool interpret_internal_type_;
  AnfNodePtr interpreted_node_;
};

@@ -221,7 +221,11 @@ def test_np_fallback_func_tensor_index():
    assert output == output_expect


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_calculate():
    """
    Feature: Fallback feature.
@@ -235,3 +239,21 @@ def test_np_calculate():
        z = Tensor(y)
        return z
    assert np.all(np_calculate().asnumpy() == np.array([1, 1, 0, 0, 1]))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tensor_array_astype():
    """
    Feature: JIT Fallback
    Description: Test Tensor(array) with astype() in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def foo():
        me_x = Tensor([1.1, -2.1]).astype("float32")
        return me_x
    print(foo())
@@ -716,7 +716,11 @@ def test_np_sort():
    assert np.all(out_where.asnumpy() == np.array([4]))


@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_extract():
    """
    Feature: JIT Fallback
@@ -334,3 +334,82 @@ def test_np_array_imag():
        return Tensor(a.imag)
    res = np_array_imag()
    print("res:", res)


def test_np_binop():
    """
    Feature: JIT Fallback
    Description: Test numpy's binary operation in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def np_binop():
        a = np.array([1, 2, 3])
        b = np.array([4, 5, 6])
        c = a + b
        return Tensor(c)
    res = np_binop()
    assert np.all(res.asnumpy() == np.array([5, 7, 9]))


def test_np_compare():
    """
    Feature: JIT Fallback
    Description: Test numpy's compare operation in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def np_compare():
        a = np.array([1, 2, 3])
        b = np.array([0, 2, 4])
        c = a > b
        return Tensor(c)
    res = np_compare()
    assert np.all(res.asnumpy() == np.array([True, False, False]))


def test_np_bool_and():
    """
    Feature: JIT Fallback
    Description: Test AND operation in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def np_bool_and():
        a = np.bool_(True)
        b = np.bool_(False)
        c = a and b
        return Tensor(c)
    res = np_bool_and()
    assert not res.asnumpy()


def test_np_bool_or():
    """
    Feature: JIT Fallback
    Description: Test OR operation in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def np_bool_or():
        a = np.bool_(True)
        b = np.bool_(False)
        c = a or b
        return Tensor(c)
    res = np_bool_or()
    assert res.asnumpy()


def test_np_bool_not():
    """
    Feature: JIT Fallback
    Description: Test NOT operation in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def np_bool_not():
        a = np.bool_(True)
        b = not a
        return Tensor(b)
    res = np_bool_not()
    assert not res.asnumpy()
@@ -9,7 +9,7 @@ from mindspore.common.initializer import One
context.set_context(mode=context.GRAPH_MODE)


def test_tensor():
def test_fallback_tensor():
    """
    Feature: JIT Fallback
    Description: Test Tensor() in graph mode.
@@ -22,7 +22,7 @@ def test_tensor():
    print(foo())


def test_tensor_bool():
def test_fallback_tensor_bool():
    """
    Feature: JIT Fallback
    Description: Test Tensor(bool) in graph mode.
@@ -35,7 +35,7 @@ def test_tensor_bool():
    print(foo())


def test_tensor_array():
def test_fallback_tensor_array():
    """
    Feature: JIT Fallback
    Description: Test Tensor(array) in graph mode.
@@ -48,7 +48,7 @@ def test_tensor_array():
    print(foo())


def test_tensor_with_mstype():
def test_fallback_tensor_with_mstype():
    """
    Feature: JIT Fallback
    Description: Test Tensor() with mstype in graph mode.
@@ -61,7 +61,7 @@ def test_tensor_with_mstype():
    print(foo())


def test_tensor_array_with_mstype():
def test_fallback_tensor_array_with_mstype():
    """
    Feature: JIT Fallback
    Description: Test Tensor(array) with mstype in graph mode.
@@ -74,21 +74,7 @@ def test_tensor_array_with_mstype():
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_array_astype():
    """
    Feature: JIT Fallback
    Description: Test Tensor(array) with astype() in graph mode.
    Expectation: No exception.
    """
    @ms_function
    def foo():
        me_x = Tensor([1.1, -2.1]).astype("float32")
        return me_x
    print(foo())


def test_tensor_with_numpy():
def test_fallback_tensor_with_numpy():
    """
    Feature: JIT Fallback
    Description: Test Tensor() with numpy in graph mode.
@@ -101,7 +87,7 @@ def test_tensor_with_numpy():
    print(foo())


def test_tensor_with_init():
def test_fallback_tensor_with_init():
    """
    Feature: JIT Fallback
    Description: Test Tensor() with init in graph mode.
@@ -114,7 +100,7 @@ def test_tensor_with_init():
    print(foo())


def test_tensor_reshape():
def test_fallback_tensor_reshape():
    """
    Feature: JIT Fallback
    Description: Test Tensor() with reshape() in graph mode.
@@ -127,8 +113,7 @@ def test_tensor_reshape():
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_abs():
def test_fallback_tensor_abs():
    """
    Feature: JIT Fallback
    Description: Test Tensor.abs() in graph mode.
@@ -136,14 +121,13 @@ def test_tensor_abs():
    """
    @ms_function
    def foo():
        a = Tensor([1.1, -2.1]).astype("float32")
        a = Tensor([1.1, -2.1])
        out = a.abs()
        return out
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_all():
def test_fallback_tensor_all():
    """
    Feature: JIT Fallback
    Description: Test Tensor.all() in graph mode.
@@ -157,8 +141,7 @@ def test_tensor_all():
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_any():
def test_fallback_tensor_any():
    """
    Feature: JIT Fallback
    Description: Test Tensor.any() in graph mode.
@@ -172,8 +155,7 @@ def test_tensor_any():
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_argmax():
def test_fallback_tensor_argmax():
    """
    Feature: JIT Fallback
    Description: Test Tensor.argmax() in graph mode.
@@ -187,8 +169,7 @@ def test_tensor_argmax():
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_argmin():
def test_fallback_tensor_argmin():
    """
    Feature: JIT Fallback
    Description: Test Tensor.argmin() in graph mode.
@@ -202,8 +183,7 @@ def test_tensor_argmin():
    print(foo())


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_astype():
def test_fallback_tensor_astype():
    """
    Feature: JIT Fallback
    Description: Test Tensor.astype() in graph mode.
@@ -244,7 +224,7 @@ def test_np_tensor_add():
    assert tensor_list[-1] == 11


def test_binop_new_tensor():
def test_fallback_tensor_binop():
    """
    Feature: Fallback feature
    Description: support binop's interpreted nodes.
@@ -303,7 +283,6 @@ def test_fallback_tensor_not():
    print("res:", res)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_and():
    """
    Feature: Fallback feature
@@ -325,7 +304,6 @@ def test_fallback_tensor_and():
    print("res:", res)


@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_or():
    """
    Feature: Fallback feature