add fallback testcases

This commit is contained in:
Margaret_wangrui 2021-10-28 21:40:41 +08:00
parent 7259b74092
commit 6275cdcf7a
2 changed files with 56 additions and 3 deletions

View File

@ -219,7 +219,7 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
if (unsupported) {
resolved_node->set_interpret(true);
AddGlobalPyParam(symbol->name(), py_obj);
MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symblol: {" << symbol->name() << " : "
MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol->name() << " : "
<< py::str(py_obj) << "}";
}
return resolved_node;
@ -268,7 +268,7 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resoleve symbol.");
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
ValueNodePtr module_node = NewValueNode(name_space);
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
@ -343,7 +343,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
// If all arguments of a φ-function are the same value s or the φfunction itself,
// then we remove the φ-function and let all users directly uses. We call such a
// φ-function obviously unnecessary.
// When we removed a φ-function p, then we recursively try to apply this simplication
// When we removed a φ-function p, then we recursively try to apply this simplification
// rule with all (former) users of p, because they may have become obviously unnecessary
// due to the removal of p
// <Quote>

View File

@ -185,6 +185,7 @@ def np_fallback_func_tensor_index(x):
return me_x[x]
# NameError: name 'array_x' is not defined.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_fallback_func_tensor_index():
"""
@ -216,6 +217,7 @@ class ControlNet(nn.Cell):
return self.inner_function_2(a, b)
# NameError: name 'mstype' is not defined.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_fallback_control_sink_tensor():
"""
@ -228,3 +230,54 @@ def test_fallback_control_sink_tensor():
output = net(x)
output_expect = Tensor(9, mstype.int32)
assert output == output_expect
# NameError: name 'mytype' is not defined
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_list():
"""
Feature: Fallback feature
Description: support Basic method of Tensor list.
Expectation: No exception.
"""
@ms_function
def np_tensor_list():
a = Tensor(np.array(4), mstype.int32)
b = Tensor(np.array(5), mstype.int32)
c = Tensor(np.array(6), mstype.int32)
tensor_list = [a, b]
for tensor in tensor_list:
print(tensor)
tensor_list.append(tensor_list[-1] + c)
return tensor_list
tensor_list = np_tensor_list()
print("tensor_list:", tensor_list)
assert len(tensor_list) == 3
# EvalCNode: This may be not defined, or it can't be a operator.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_add():
"""
Feature: Fallback feature
Description: support Tensor add.
Expectation: No exception.
"""
@ms_function
def np_tensor_add():
a = Tensor(np.array(4))
b = Tensor(np.array(5))
tensor_list = [a, b]
for tensor in tensor_list:
print(tensor)
x = 6
np_x = np.array(x)
c = Tensor(np_x)
d = tensor_list[-1] + c
tensor_list.append(d)
return tensor_list
tensor_list = np_tensor_add()
print("tensor_list:", tensor_list)
assert tensor_list[-1] == 11