forked from mindspore-Ecosystem/mindspore
commit
311d1be98d
|
@ -2064,6 +2064,14 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
|
|||
// Call python script string.
|
||||
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
|
||||
|
||||
// when return value should be none
|
||||
if (current_interpret_node->has_user_data("__py_execute_no_return_type__")) {
|
||||
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
|
||||
res->set_value(kAnyValue);
|
||||
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
|
||||
return infer_result;
|
||||
}
|
||||
TypePtr type = kFloat64;
|
||||
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
|
||||
type = current_interpret_node->user_data<Type>("__py_execute_tensor_type__");
|
||||
|
@ -2739,12 +2747,6 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cur_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cur_graph);
|
||||
if (cur_graph->is_tensor_condition_branch()) {
|
||||
MS_LOG(EXCEPTION) << "Currently only supports raise in constant scenarios. "
|
||||
<< "Tensor type data cannot exist in the conditional statement. "
|
||||
<< "Please check your conditions which raise node is located at: "
|
||||
<< trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
if (args_abs_list.empty()) {
|
||||
// process raise
|
||||
MS_LOG(EXCEPTION) << "No active exception to reraise.";
|
||||
|
@ -2772,7 +2774,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
const auto input = inputs[index];
|
||||
auto input_abs = args_abs_list[index - 1];
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
bool need_symbol = CheckNeedSymbol(input, input_abs);
|
||||
const bool need_symbol = CheckNeedSymbol(input, input_abs);
|
||||
if (need_symbol) {
|
||||
exception_string += "'";
|
||||
}
|
||||
|
@ -2807,6 +2809,8 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
// Build the PyExecute node for raise error.
|
||||
const auto raise_error_node = cur_graph->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
auto none_type = std::make_shared<TypeNone>();
|
||||
raise_error_node->set_user_data<Type>("__py_execute_no_return_type__", none_type);
|
||||
cur_graph->ReplaceInOrder(node, raise_error_node);
|
||||
AnalysisEnginePtr eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
|
@ -2849,7 +2853,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
return need_symbol;
|
||||
}
|
||||
std::string GetExceptionString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node,
|
||||
const bool need_comma = false, bool need_symbol = false) {
|
||||
const bool need_comma = false, const bool need_symbol = false) {
|
||||
std::string exception_str;
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<abstract::AbstractSequence>()) {
|
||||
|
@ -2874,21 +2878,25 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
|
||||
std::string GetTupleOrListString(const AbstractBasePtr &arg, const AnfNodePtr &input, const AnfNodePtr &node,
|
||||
const bool need_comma, bool need_symbol = false) {
|
||||
const bool need_comma, const bool need_symbol = false) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
std::string exception_str;
|
||||
bool is_tuple = arg->isa<abstract::AbstractTuple>();
|
||||
// Process raise ValueError("str")
|
||||
auto arg_tuple = arg->cast_ptr<abstract::AbstractSequence>();
|
||||
MS_EXCEPTION_IF_NULL(arg_tuple);
|
||||
auto const &arg_tuple_elements = arg_tuple->elements();
|
||||
const auto &arg_tuple_elements = arg_tuple->elements();
|
||||
if (arg_tuple_elements.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "The arg_tuple_elements can't be empty.";
|
||||
}
|
||||
if (!input->isa<CNode>()) {
|
||||
std::string key = "__internal_error_value" + std::to_string(num_str_) + "__";
|
||||
num_str_ += 1;
|
||||
exception_str = exception_str + "{" + key + "}";
|
||||
if (need_symbol) {
|
||||
exception_str = exception_str + "'+f'{" + key + "}'+'";
|
||||
} else {
|
||||
exception_str = exception_str + key;
|
||||
}
|
||||
(void)keys_.emplace_back(NewValueNode(std::make_shared<StringImm>(key)));
|
||||
(void)values_.emplace_back(input);
|
||||
return exception_str;
|
||||
|
|
|
@ -101,6 +101,24 @@ bool CheckAbstractScalar(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool CheckIfRaise(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
auto first = inputs[1];
|
||||
auto script_node = first->cast<ValueNodePtr>();
|
||||
if (script_node->value()->isa<StringImm>()) {
|
||||
auto script = GetValueNode<StringImmPtr>(script_node)->value();
|
||||
std::string raise_script = "raise_func";
|
||||
auto idx = script.find(raise_script);
|
||||
if (idx != string::npos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ValidateAbstract(const AnfNodePtr &node) {
|
||||
if (node == nullptr) {
|
||||
MS_LOG(DEBUG) << "Node to validate is invalid";
|
||||
|
@ -123,6 +141,11 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
MS_LOG(DEBUG) << "AbstractError in the graph: " << abstract->ToString();
|
||||
return;
|
||||
}
|
||||
if (CheckIfRaise(node)) {
|
||||
ShapeVector shp{abstract::Shape::kShapeRankAny};
|
||||
auto abs = std::make_shared<abstract::AbstractTensor>(kFloat64, std::make_shared<abstract::Shape>(shp));
|
||||
node->set_abstract(abs);
|
||||
}
|
||||
bool is_legal_abstract = abstract->isa<AbstractType>() || abstract->isa<AbstractFunction>() ||
|
||||
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
|
||||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test graph fallback control flow."""
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore
|
||||
from mindspore import Tensor, jit, context
|
||||
|
@ -50,43 +49,6 @@ def test_for_after_while_in_if_1():
|
|||
assert res == 8
|
||||
|
||||
|
||||
@case_register.level1
|
||||
@case_register.target_gpu
|
||||
@case_register.target_ascend
|
||||
def test_for_after_while_in_if_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def func3202():
|
||||
x = Tensor([2])
|
||||
y = Tensor([2])
|
||||
if x == y and x > Tensor([0]).astype("int32"):
|
||||
y = y + Tensor([3]).astype("int32")
|
||||
z = Tensor([5], dtype=mindspore.int32)
|
||||
while y == z and z == Tensor([5], dtype=mindspore.int32):
|
||||
y = y * 1
|
||||
assert y == z, "y must equal z"
|
||||
z = z + 1
|
||||
|
||||
else:
|
||||
raise ValueError("Enter this branch, not expect!")
|
||||
|
||||
for i in range(3):
|
||||
z = Tensor([i]).astype("int32")
|
||||
x = x + z
|
||||
|
||||
return x, y
|
||||
|
||||
with pytest.raises(RuntimeError, match="Currently only supports raise in constant scenarios."):
|
||||
res_x, res_y = func3202()
|
||||
assert res_x == 5
|
||||
assert res_y == 5
|
||||
|
||||
|
||||
@case_register.level1
|
||||
@case_register.target_gpu
|
||||
@case_register.target_ascend
|
||||
|
|
|
@ -168,3 +168,128 @@ def test_raise_with_variable_joinedstr_tensor():
|
|||
print("res:", res)
|
||||
assert "The input should not be 1" in str(
|
||||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_with_variable_dic():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self):
|
||||
x = Tensor(1)
|
||||
y = Tensor(2)
|
||||
z = {"x": x, "y": y}
|
||||
raise ValueError(z)
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_list:
|
||||
net = RaiseNet()
|
||||
res = net()
|
||||
print("res:", res)
|
||||
assert "Dictionary type is currently not supporting" in str(
|
||||
raise_info_list.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_with_variable_control_flow1():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
if x == y:
|
||||
raise RuntimeError(f"The input should not be {x}.")
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
x = Tensor(1)
|
||||
y = Tensor(1)
|
||||
res = net(x, y)
|
||||
print("res:", res)
|
||||
assert "The input should not be 1" in str(
|
||||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_with_variable_control_flow2():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
if x == y:
|
||||
raise RuntimeError(f"The input should not be {x}.")
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
x = Tensor(1)
|
||||
y = Tensor(1)
|
||||
res = net(x, y)
|
||||
print("res:", res)
|
||||
assert "The input should not be 1" in str(
|
||||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_with_variable_control_flow3():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
if x == y:
|
||||
raise RuntimeError(f"The input should not be {x}.")
|
||||
return z
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
x = Tensor(1)
|
||||
y = Tensor(1)
|
||||
z = (x, y)
|
||||
res = net(x, y, z)
|
||||
print("res:", res)
|
||||
assert "The input should not be 1" in str(
|
||||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_raise_with_variable_control_flow4():
|
||||
"""
|
||||
Feature: graph raise by JIT Fallback.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
if x == y:
|
||||
raise RuntimeError(f"The input should not be {x}.")
|
||||
return z
|
||||
|
||||
with pytest.raises(RuntimeError) as raise_info_joinedstr_tensor:
|
||||
net = RaiseNet()
|
||||
x = Tensor(1)
|
||||
y = Tensor(1)
|
||||
z = [x, y]
|
||||
res = net(x, y, z)
|
||||
print("res:", res)
|
||||
assert "The input should not be 1" in str(
|
||||
raise_info_joinedstr_tensor.value)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# ============================================================================
|
||||
""" test_assert """
|
||||
import pytest
|
||||
from mindspore import nn, context, Tensor
|
||||
from mindspore import nn, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
@ -153,22 +153,3 @@ def test_assert7():
|
|||
net()
|
||||
assert "1 not in [2, 3, 4]" in str(excinfo.value)
|
||||
assert "assert x in [2, 3, 4]" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_assert8():
|
||||
"""
|
||||
Feature: support assert
|
||||
Description: test assert with variable in condition
|
||||
Expectation: no error
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
assert x == 1
|
||||
return x
|
||||
|
||||
net = Net()
|
||||
a = Tensor(1)
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
net(a)
|
||||
assert "Currently only supports raise in constant scenarios." in str(excinfo.value)
|
||||
|
|
|
@ -17,7 +17,6 @@ import pytest
|
|||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -77,25 +76,6 @@ def test_raise_3():
|
|||
print("res:", res)
|
||||
|
||||
|
||||
def test_raise_4():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
Description: Test raise.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class RaiseNet(nn.Cell):
|
||||
def construct(self, x):
|
||||
if x == 1:
|
||||
raise ValueError(f"The input should not be 1.")
|
||||
return x
|
||||
|
||||
with pytest.raises(RuntimeError, match="Currently only supports raise in constant scenarios."):
|
||||
net = RaiseNet()
|
||||
x = Tensor(9, mstype.int32)
|
||||
res = net(x)
|
||||
assert res == 9
|
||||
|
||||
|
||||
def test_raise_5():
|
||||
"""
|
||||
Feature: graph raise.
|
||||
|
|
|
@ -92,6 +92,7 @@ def test_nest_function_missing_return():
|
|||
assert "For 'make_range', the 0th input should be a int64 scalar" in str(er.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Case will not appear for now, but may appear in the future')
|
||||
def test_raise_in_method():
|
||||
class NetRaiseInMethod(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
|
@ -112,31 +113,6 @@ def test_raise_in_method():
|
|||
assert "Currently only supports raise in constant scenarios." in str(er.value)
|
||||
|
||||
|
||||
def test_raise_in_nested_function():
|
||||
class NetNestRaise(nn.Cell):
|
||||
def nest_fn(self, u):
|
||||
if u > 0:
|
||||
# add not support grammar 'raise' here
|
||||
raise ValueError('Illegal case')
|
||||
return u + z + 1
|
||||
|
||||
def construct(self, x, y, z):
|
||||
if x == 1:
|
||||
return Tensor(10, mstype.int32)
|
||||
elif x == 20:
|
||||
return self.nest_fn(y)
|
||||
else:
|
||||
return y + z
|
||||
|
||||
net = NetNestRaise()
|
||||
x = Tensor(0, mstype.int32)
|
||||
y = Tensor(5, mstype.int32)
|
||||
z = Tensor(2, mstype.int32)
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
net(x, y, z)
|
||||
assert "Currently only supports raise in constant scenarios." in str(er.value)
|
||||
|
||||
|
||||
def test_nest_branch_with_return():
|
||||
class NetBranchWithReturn(nn.Cell):
|
||||
def construct(self, x, y, z):
|
||||
|
|
Loading…
Reference in New Issue