fallback parse Tensor, different from NumPy.

This commit is contained in:
huangbingjian 2021-12-11 13:01:24 +08:00
parent 21176d6c51
commit db4567989a
10 changed files with 220 additions and 78 deletions

View File

@ -69,6 +69,12 @@ AST_SUB_TYPE_STARRED = 8 # ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 9 # ast.Attribute
AST_SUB_TYPE_UNKNOWN = 0xFF # unknown
# Syntax support
SYNTAX_SUPPORTED = 0 # supported syntax
SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1 # unsupported internal type
SYNTAX_UNSUPPORTED_EXTERNAL_TYPE = 2 # unsupported external type
SYNTAX_UNSUPPORTED_NAMESPACE = 3 # unsupported namespace
# Process expr statement white list
# add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
parse_expr_statement_white_list = (
@ -77,6 +83,14 @@ parse_expr_statement_white_list = (
_builtin_function_or_method_type = type(abs)
_unsupported_python_builtin_type = (
list, tuple, set, dict, slice, bool, int, float, str, complex, reversed,
)
_unsupported_internal_type = (
Tensor,
)
def create_slice_obj(start, end, step):
"""Create slice object"""
@ -613,6 +627,20 @@ class Parser:
logger.debug(f"'{value}' unsupported: {unsupported}.")
return unsupported
def is_unsupported_python_builtin_type(self, value):
"""To check if not supported builtin type"""
unsupported = value in _unsupported_python_builtin_type
logger.debug(f"value: '{value}', unsupported builtin type: {unsupported}.")
return unsupported
def is_unsupported_internal_type(self, value):
"""To check if not supported internal type, such as Tensor"""
for item in _unsupported_internal_type:
if value == item:
logger.debug(f"Found unsupported internal type: '{value}'.")
return True
return False
def get_namespace_symbol(self, var: str):
"""Get symbol type and namespace and symbol."""
if var in self.closure_namespace:
@ -629,13 +657,6 @@ class Parser:
error_info = f"The name '{var}' is not defined in function '{self.function_name}'."
return None, error_info
def is_unsupported_builtin_type(self, value_type):
"""To check if not supported builtin type"""
unsupported_builtin_type = (list, tuple, set, dict, slice, bool, int, float, str, complex, reversed)
is_unsupported = value_type in unsupported_builtin_type
logger.debug(f"value_type: {value_type}, unsupported builtin type: {is_unsupported}.")
return is_unsupported
def is_supported_namespace_module(self, value):
"""To check if the module is allowed to support."""
# Check `mindspore` namespace.
@ -656,11 +677,6 @@ class Parser:
logger.debug(f"Found 'mindspore.numpy' namespace.")
return True
# Check `Tensor` namespace.
if value == Tensor:
logger.debug(f"Not support '{name}'.")
return False
# Check `builtins` namespace.
if hasattr(value, '__module__'): # Not types.ModuleType
mod = value.__module__
@ -713,14 +729,13 @@ class Parser:
value_str = value.__name__ if hasattr(value, '__name__') else str(value)
logger.debug(f"value: {type(value)}, '{value_str}', hasattr(__name__): {hasattr(value, '__name__')}.")
# To check if allowed to support.
if self.is_unsupported_namespace(value):
return self.global_namespace, var, value
if self.is_unsupported_builtin_type(value):
return self.global_namespace, var, value
if not self.is_supported_namespace_module(value): # Check if support including instance of types.ModuleType
return self.global_namespace, var, value
supported = True
return self.global_namespace, var, value, supported
if self.is_unsupported_internal_type(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_INTERNAL_TYPE
if self.is_unsupported_python_builtin_type(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_EXTERNAL_TYPE
if self.is_unsupported_namespace(value) or not self.is_supported_namespace_module(value):
return self.global_namespace, var, value, SYNTAX_UNSUPPORTED_NAMESPACE
return self.global_namespace, var, value, SYNTAX_SUPPORTED
error_info = f"The name '{var}' is not defined, or not supported in graph mode."
logger.debug(f"error_info: {error_info}")

View File

@ -220,13 +220,13 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &info) {
AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
constexpr size_t closure_info_size = 2;
constexpr size_t unsupported_info_size = 3;
constexpr size_t supported_info_size = 4;
constexpr size_t namespace_info_size = 4;
constexpr size_t namespace_index = 0;
constexpr size_t symbol_index = 1;
constexpr size_t value_index = 2;
if (info.size() < closure_info_size || info.size() > supported_info_size) {
MS_EXCEPTION(NameError) << "namespace info size should be 2, 3 or 4, but got " << info.size();
constexpr size_t flag_index = 3;
if (info.size() != closure_info_size && info.size() != namespace_info_size) {
MS_EXCEPTION(NameError) << "namespace info size should be 2 or 4, but got " << info.size();
}
// Handle closure namespace info.
@ -240,8 +240,12 @@ AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
// Handle global namespace info.
auto resolved_node = GetResolveNode(info);
if (info.size() == unsupported_info_size) {
auto syntax_support = info[flag_index].cast<int32_t>();
if (syntax_support != SYNTAX_SUPPORTED) {
resolved_node->set_interpret(true);
if (syntax_support == SYNTAX_UNSUPPORTED_INTERNAL_TYPE) {
resolved_node->set_interpret_internal_type(true);
}
}
SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
py::object py_obj = info[value_index];
@ -309,6 +313,7 @@ AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const An
auto node = func_graph_->NewCNodeInOrder(
{NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
node->set_interpreted_node(orig_node);
node->set_interpret_internal_type(orig_node->interpret_internal_type());
return node;
}

View File

@ -560,6 +560,7 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob
AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast BinOP";
MS_EXCEPTION_IF_NULL(block);
py::object left = python_adapter::GetPyObjAttr(node, "left");
py::object right = python_adapter::GetPyObjAttr(node, "right");
py::object op = python_adapter::GetPyObjAttr(node, "op");
@ -575,11 +576,12 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n
}
right_node = HandleInterpret(block, right_node, right);
// Resolve the op
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_node = block->MakeResolveAstOp(op);
// Create apply node
MS_EXCEPTION_IF_NULL(block->func_graph());
return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
UpdateInterpretForUserNode(left_node, new_node);
return new_node;
}
AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
@ -733,9 +735,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
bool need_unpack = need_unpack_args || need_unpack_keywords;
auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
if (call_function_node->interpret()) {
call_cnode->set_interpret(true);
}
UpdateInterpretForUserNode(call_function_node, call_cnode);
return call_cnode;
}
@ -878,9 +878,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
// Create the apply node
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
if (value_node->interpret() || IsPrimitiveCNode(value_node, prim::kPrimPyInterpret)) {
attr_cnode->set_interpret(true);
}
UpdateInterpretForUserNode(value_node, attr_cnode);
return attr_cnode;
}
@ -908,7 +906,9 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
UpdateInterpretForUserNode(left_node, new_node);
return new_node;
}
AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
@ -964,6 +964,7 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p
std::vector<AnfNodePtr> call_graph_nodes{switch_app};
auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
UpdateInterpretForUserNode(test_node, switch_app_call);
return switch_app_call;
}
}
@ -1101,7 +1102,7 @@ AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &n
return ParseExprNode(block, value_node);
}
// Process a UnaryOp, +a, -b
// Process a UnaryOp, +a, -b
AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast UnaryOp";
py::object op = python_adapter::GetPyObjAttr(node, "op");
@ -1113,7 +1114,9 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object
py::object operand = python_adapter::GetPyObjAttr(node, "operand");
AnfNodePtr operand_node = ParseExprNode(block, operand);
operand_node = HandleInterpret(block, operand_node, operand);
return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
auto new_node = block->func_graph()->NewCNodeInOrder({op_node, operand_node});
UpdateInterpretForUserNode(operand_node, new_node);
return new_node;
}
// Process a dict ast node expression
@ -1179,6 +1182,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
WriteAssignVars(block, target_object, augassign_app);
return block;
}
// Process global declaration such as 'global x';
FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Global";
@ -1891,6 +1895,17 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
}
}
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node) {
// Do not handle user node with internal type such as Tensor.abs().
bool interpret_without_internal = IsPrimitiveCNode(node, prim::kPrimPyInterpret) && !node->interpret_internal_type();
if (node->interpret() || interpret_without_internal) {
user_node->set_interpret(true);
if (node->interpret_internal_type()) {
user_node->set_interpret_internal_type(true);
}
}
}
bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) {
// Check global parameters.

View File

@ -190,13 +190,15 @@ class Parser {
// Check if script_text is in global/local params.
bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph);
// Set the interpret flag for the node calling the interpret node.
void UpdateInterpretForUserNode(const AnfNodePtr &node, const AnfNodePtr &user_node);
// Check if the node need interpreting.
AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
const py::object &value_object);
// Generate argument nodes for ast function node
// Generate argument nodes for ast function node
void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// Generate argument default value for ast function node
// Generate argument default value for ast function node
void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
// Parse ast function node
FunctionBlockPtr ParseDefFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);

View File

@ -166,6 +166,14 @@ enum ClassInstanceTypeDef {
CLASS_INSTANCE_TYPE_INVALID = 0xFF
};
// Define syntax support.
enum SyntaxSupportDef {
SYNTAX_SUPPORTED = 0, // supported syntax
SYNTAX_UNSUPPORTED_INTERNAL_TYPE = 1, // unsupported internal type
SYNTAX_UNSUPPORTED_EXTERNAL_TYPE = 2, // unsupported external type
SYNTAX_UNSUPPORTED_NAMESPACE = 3 // unsupported namespace
};
// Convert python object to ValuePtr.
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr);

View File

@ -119,6 +119,7 @@ class MS_CORE_API AnfNode : public Base {
scope_(ScopeManager::GetInstance().GetCurrentScope()),
kernel_info_(nullptr),
interpret_(false),
interpret_internal_type_(false),
interpreted_node_(nullptr) {}
/// \brief Constructor.
@ -345,6 +346,18 @@ class MS_CORE_API AnfNode : public Base {
/// \param[in] interpret Boolean.
void set_interpret(const bool &interpret) { interpret_ = interpret; }
/// \brief Check if there is an interpret node related to the unsupported internal type.
///
/// \return True if there is an interpret node related to the unsupported internal type, otherwise false.
bool interpret_internal_type() { return interpret_internal_type_; }
/// \brief Whether there is an interpret node with unsupported internal type.
///
/// \param[in] interpret_internal_type Boolean.
void set_interpret_internal_type(const bool &interpret_internal_type) {
interpret_internal_type_ = interpret_internal_type;
}
/// \brief Get interpreted node.
///
/// \return Interpreted node.
@ -369,6 +382,7 @@ class MS_CORE_API AnfNode : public Base {
KernelInfoDevicePtr kernel_info_;
UserData user_data_;
bool interpret_;
bool interpret_internal_type_;
AnfNodePtr interpreted_node_;
};

View File

@ -221,7 +221,11 @@ def test_np_fallback_func_tensor_index():
assert output == output_expect
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_calculate():
"""
Feature: Fallback feature.
@ -235,3 +239,21 @@ def test_np_calculate():
z = Tensor(y)
return z
assert np.all(np_calculate().asnumpy() == np.array([1, 1, 0, 0, 1]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tensor_array_astype():
"""
Feature: JIT Fallback
Description: Test Tensor(array) with astype() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
me_x = Tensor([1.1, -2.1]).astype("float32")
return me_x
print(foo())

View File

@ -716,7 +716,11 @@ def test_np_sort():
assert np.all(out_where.asnumpy() == np.array([4]))
@pytest.mark.skip(reason='Not support graph fallback feature yet')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_extract():
"""
Feature: JIT Fallback

View File

@ -334,3 +334,82 @@ def test_np_array_imag():
return Tensor(a.imag)
res = np_array_imag()
print("res:", res)
def test_np_binop():
"""
Feature: JIT Fallback
Description: Test numpy's binary operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_binop():
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
c = a + b
return Tensor(c)
res = np_binop()
assert np.all(res.asnumpy() == np.array([5, 7, 9]))
def test_np_compare():
"""
Feature: JIT Fallback
Description: Test numpy's compare operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_compare():
a = np.array([1, 2, 3])
b = np.array([0, 2, 4])
c = a > b
return Tensor(c)
res = np_compare()
assert np.all(res.asnumpy() == np.array([True, False, False]))
def test_np_bool_and():
"""
Feature: JIT Fallback
Description: Test AND operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_bool_and():
a = np.bool_(True)
b = np.bool_(False)
c = a and b
return Tensor(c)
res = np_bool_and()
assert not res.asnumpy()
def test_np_bool_or():
"""
Feature: JIT Fallback
Description: Test OR operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_bool_or():
a = np.bool_(True)
b = np.bool_(False)
c = a or b
return Tensor(c)
res = np_bool_or()
assert res.asnumpy()
def test_np_bool_not():
"""
Feature: JIT Fallback
Description: Test NOT operation in graph mode.
Expectation: No exception.
"""
@ms_function
def np_bool_not():
a = np.bool_(True)
b = not a
return Tensor(b)
res = np_bool_not()
assert not res.asnumpy()

View File

@ -9,7 +9,7 @@ from mindspore.common.initializer import One
context.set_context(mode=context.GRAPH_MODE)
def test_tensor():
def test_fallback_tensor():
"""
Feature: JIT Fallback
Description: Test Tensor() in graph mode.
@ -22,7 +22,7 @@ def test_tensor():
print(foo())
def test_tensor_bool():
def test_fallback_tensor_bool():
"""
Feature: JIT Fallback
Description: Test Tensor(bool) in graph mode.
@ -35,7 +35,7 @@ def test_tensor_bool():
print(foo())
def test_tensor_array():
def test_fallback_tensor_array():
"""
Feature: JIT Fallback
Description: Test Tensor(array) in graph mode.
@ -48,7 +48,7 @@ def test_tensor_array():
print(foo())
def test_tensor_with_mstype():
def test_fallback_tensor_with_mstype():
"""
Feature: JIT Fallback
Description: Test Tensor() with mstype in graph mode.
@ -61,7 +61,7 @@ def test_tensor_with_mstype():
print(foo())
def test_tensor_array_with_mstype():
def test_fallback_tensor_array_with_mstype():
"""
Feature: JIT Fallback
Description: Test Tensor(array) with mstype in graph mode.
@ -74,21 +74,7 @@ def test_tensor_array_with_mstype():
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_array_astype():
"""
Feature: JIT Fallback
Description: Test Tensor(array) with astype() in graph mode.
Expectation: No exception.
"""
@ms_function
def foo():
me_x = Tensor([1.1, -2.1]).astype("float32")
return me_x
print(foo())
def test_tensor_with_numpy():
def test_fallback_tensor_with_numpy():
"""
Feature: JIT Fallback
Description: Test Tensor() with numpy in graph mode.
@ -101,7 +87,7 @@ def test_tensor_with_numpy():
print(foo())
def test_tensor_with_init():
def test_fallback_tensor_with_init():
"""
Feature: JIT Fallback
Description: Test Tensor() with init in graph mode.
@ -114,7 +100,7 @@ def test_tensor_with_init():
print(foo())
def test_tensor_reshape():
def test_fallback_tensor_reshape():
"""
Feature: JIT Fallback
Description: Test Tensor() with reshape() in graph mode.
@ -127,8 +113,7 @@ def test_tensor_reshape():
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_abs():
def test_fallback_tensor_abs():
"""
Feature: JIT Fallback
Description: Test Tensor.abs() in graph mode.
@ -136,14 +121,13 @@ def test_tensor_abs():
"""
@ms_function
def foo():
a = Tensor([1.1, -2.1]).astype("float32")
a = Tensor([1.1, -2.1])
out = a.abs()
return out
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_all():
def test_fallback_tensor_all():
"""
Feature: JIT Fallback
Description: Test Tensor.all() in graph mode.
@ -157,8 +141,7 @@ def test_tensor_all():
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_any():
def test_fallback_tensor_any():
"""
Feature: JIT Fallback
Description: Test Tensor.any() in graph mode.
@ -172,8 +155,7 @@ def test_tensor_any():
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_argmax():
def test_fallback_tensor_argmax():
"""
Feature: JIT Fallback
Description: Test Tensor.argmax() in graph mode.
@ -187,8 +169,7 @@ def test_tensor_argmax():
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_argmin():
def test_fallback_tensor_argmin():
"""
Feature: JIT Fallback
Description: Test Tensor.argmin() in graph mode.
@ -202,8 +183,7 @@ def test_tensor_argmin():
print(foo())
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_tensor_astype():
def test_fallback_tensor_astype():
"""
Feature: JIT Fallback
Description: Test Tensor.astype() in graph mode.
@ -244,7 +224,7 @@ def test_np_tensor_add():
assert tensor_list[-1] == 11
def test_binop_new_tensor():
def test_fallback_tensor_binop():
"""
Feature: Fallback feature
Description: support binop's interpreted nodes.
@ -303,7 +283,6 @@ def test_fallback_tensor_not():
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_and():
"""
Feature: Fallback feature
@ -325,7 +304,6 @@ def test_fallback_tensor_and():
print("res:", res)
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_tensor_or():
"""
Feature: Fallback feature