forked from mindspore-Ecosystem/mindspore
add fallback testcases
This commit is contained in:
parent
7259b74092
commit
6275cdcf7a
|
@ -219,7 +219,7 @@ AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &namespace_info) {
|
||||||
if (unsupported) {
|
if (unsupported) {
|
||||||
resolved_node->set_interpret(true);
|
resolved_node->set_interpret(true);
|
||||||
AddGlobalPyParam(symbol->name(), py_obj);
|
AddGlobalPyParam(symbol->name(), py_obj);
|
||||||
MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symblol: {" << symbol->name() << " : "
|
MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol->name() << " : "
|
||||||
<< py::str(py_obj) << "}";
|
<< py::str(py_obj) << "}";
|
||||||
}
|
}
|
||||||
return resolved_node;
|
return resolved_node;
|
||||||
|
@ -268,7 +268,7 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
||||||
|
|
||||||
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
|
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
|
||||||
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
|
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
|
||||||
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resoleve symbol.");
|
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
|
||||||
ValueNodePtr module_node = NewValueNode(name_space);
|
ValueNodePtr module_node = NewValueNode(name_space);
|
||||||
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
|
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
|
||||||
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
|
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
|
||||||
|
@ -343,7 +343,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
|
||||||
// If all arguments of a φ-function are the same value s or the φfunction itself,
|
// If all arguments of a φ-function are the same value s or the φfunction itself,
|
||||||
// then we remove the φ-function and let all users directly uses. We call such a
|
// then we remove the φ-function and let all users directly uses. We call such a
|
||||||
// φ-function obviously unnecessary.
|
// φ-function obviously unnecessary.
|
||||||
// When we removed a φ-function p, then we recursively try to apply this simplification
|
// When we removed a φ-function p, then we recursively try to apply this simplification
|
||||||
// rule with all (former) users of p, because they may have become obviously unnecessary
|
// rule with all (former) users of p, because they may have become obviously unnecessary
|
||||||
// due to the removal of p
|
// due to the removal of p
|
||||||
// <Quote>
|
// <Quote>
|
||||||
|
|
|
@ -185,6 +185,7 @@ def np_fallback_func_tensor_index(x):
|
||||||
return me_x[x]
|
return me_x[x]
|
||||||
|
|
||||||
|
|
||||||
|
# NameError: name 'array_x' is not defined.
|
||||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
def test_np_fallback_func_tensor_index():
|
def test_np_fallback_func_tensor_index():
|
||||||
"""
|
"""
|
||||||
|
@ -216,6 +217,7 @@ class ControlNet(nn.Cell):
|
||||||
return self.inner_function_2(a, b)
|
return self.inner_function_2(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
# NameError: name 'mstype' is not defined.
|
||||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
def test_fallback_control_sink_tensor():
|
def test_fallback_control_sink_tensor():
|
||||||
"""
|
"""
|
||||||
|
@ -228,3 +230,54 @@ def test_fallback_control_sink_tensor():
|
||||||
output = net(x)
|
output = net(x)
|
||||||
output_expect = Tensor(9, mstype.int32)
|
output_expect = Tensor(9, mstype.int32)
|
||||||
assert output == output_expect
|
assert output == output_expect
|
||||||
|
|
||||||
|
|
||||||
|
# NameError: name 'mytype' is not defined
|
||||||
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
|
def test_np_tensor_list():
|
||||||
|
"""
|
||||||
|
Feature: Fallback feature
|
||||||
|
Description: support Basic method of Tensor list.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_function
|
||||||
|
def np_tensor_list():
|
||||||
|
a = Tensor(np.array(4), mstype.int32)
|
||||||
|
b = Tensor(np.array(5), mstype.int32)
|
||||||
|
c = Tensor(np.array(6), mstype.int32)
|
||||||
|
tensor_list = [a, b]
|
||||||
|
for tensor in tensor_list:
|
||||||
|
print(tensor)
|
||||||
|
tensor_list.append(tensor_list[-1] + c)
|
||||||
|
return tensor_list
|
||||||
|
|
||||||
|
tensor_list = np_tensor_list()
|
||||||
|
print("tensor_list:", tensor_list)
|
||||||
|
assert len(tensor_list) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# EvalCNode: This may be not defined, or it can't be a operator.
|
||||||
|
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||||
|
def test_np_tensor_add():
|
||||||
|
"""
|
||||||
|
Feature: Fallback feature
|
||||||
|
Description: support Tensor add.
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
@ms_function
|
||||||
|
def np_tensor_add():
|
||||||
|
a = Tensor(np.array(4))
|
||||||
|
b = Tensor(np.array(5))
|
||||||
|
tensor_list = [a, b]
|
||||||
|
for tensor in tensor_list:
|
||||||
|
print(tensor)
|
||||||
|
x = 6
|
||||||
|
np_x = np.array(x)
|
||||||
|
c = Tensor(np_x)
|
||||||
|
d = tensor_list[-1] + c
|
||||||
|
tensor_list.append(d)
|
||||||
|
return tensor_list
|
||||||
|
|
||||||
|
tensor_list = np_tensor_add()
|
||||||
|
print("tensor_list:", tensor_list)
|
||||||
|
assert tensor_list[-1] == 11
|
||||||
|
|
Loading…
Reference in New Issue