fix code warning && remove save_graphs use in st/ut

This commit is contained in:
huanghui 2021-12-23 11:18:33 +08:00
parent 39b3fc2922
commit 74ca50e652
14 changed files with 36 additions and 14 deletions

View File

@ -46,6 +46,7 @@ using mindspore::tensor::TensorPy;
namespace mindspore {
std::string GetKernelNodeName(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string kernel_name = anf_node->fullname_with_scope();
if (kernel_name.empty()) {
kernel_name = anf_node->ToString();
@ -57,6 +58,7 @@ std::string GetKernelNodeName(const AnfNodePtr &anf_node) {
// ============================================= MindSpore IR Exporter =============================================
std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) {
MS_EXCEPTION_IF_NULL(nd);
ValuePtr tensor_value = nullptr;
auto abstract = nd->abstract();
if (abstract != nullptr && abstract->isa<abstract::AbstractTensor>()) {
@ -479,6 +481,9 @@ void AnfExporter::OutputStatementComment(std::ofstream &ofs, const CNodePtr &nod
void AnfExporter::OutputCNodeText(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx,
std::map<AnfNodePtr, int> *const apply_map) {
if (cnode == nullptr || func_graph == nullptr || idx == nullptr || apply_map == nullptr) {
return;
}
auto &inputs = cnode->inputs();
std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map);
std::string fv_text = (cnode->func_graph() != func_graph) ? ("$(" + cnode->func_graph()->ToString() + "):") : "";

View File

@ -39,6 +39,7 @@ void AnalysisSchedule::Schedule() {
}
void AnalysisSchedule::Yield(const AsyncInferTask *async_infer_task) {
MS_EXCEPTION_IF_NULL(async_infer_task);
{
std::lock_guard<std::mutex> activeLock(activate_thread_lock_);
if (async_infer_task->ready() == 0) {

View File

@ -36,7 +36,7 @@ class Tensor(Tensor_):
Tensor is a data structure that stores an n-dimensional array.
Args:
input_data (Union[Tensor, float, int, bool, tuple, list, numpy.ndarray]): The data to be stroed. It can be
input_data (Union[Tensor, float, int, bool, tuple, list, numpy.ndarray]): The data to be stored. It can be
another Tensor, Python number or NumPy ndarray. Default: None.
dtype (:class:`mindspore.dtype`): Used to indicate the data type of the output Tensor. The argument should
be defined in `mindspore.dtype`. If it is None, the data type of the output Tensor will be the same

View File

@ -21,6 +21,7 @@ from mindspore.ops import composite as C
from mindspore import context
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE)
@ -78,9 +79,11 @@ def test_forward():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_backward():
# Graph Mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
graph_forward_net = ForwardNet(max_cycles=10)

View File

@ -19,6 +19,7 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE)
@ -69,8 +70,10 @@ def test_forward():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)

View File

@ -20,6 +20,7 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE)
@ -74,8 +75,10 @@ def test_forward():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)

View File

@ -20,6 +20,7 @@ from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from tests.security_utils import security_off_wrap
grad_all = C.GradOperation(get_all=True)
@ -75,6 +76,7 @@ def test_for_in_while_01():
@pytest.mark.skip(reason="not supported for in while")
@security_off_wrap
def test_for_in_while_02():
class ForInWhileNet(nn.Cell):
def __init__(self):
@ -108,6 +110,7 @@ def test_for_in_while_02():
# graph mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
for_in_while_net = ForInWhileNet()
net = GradNet(for_in_while_net)
graph_forward_res = for_in_while_net(x)

View File

@ -22,6 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore import context
from mindspore.common.parameter import Parameter
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE)
@ -132,11 +133,13 @@ class BackwardNetNoAssign(nn.Cell):
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_backward_no_assign():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
# Graph Mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
graph_forward_net = ForwardNetNoAssign(max_cycles=3)
graph_backward_net = BackwardNetNoAssign(graph_forward_net)
graph_mode_grads = graph_backward_net(x, y)

View File

@ -20,6 +20,7 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE)
@ -75,8 +76,10 @@ def test_forward():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_backward():
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)

View File

@ -20,6 +20,7 @@ from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore import context
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE)
@ -63,8 +64,10 @@ class BackwardNet(nn.Cell):
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_forward():
context.set_context(mode=context.GRAPH_MODE)
context.set_context(save_graphs=True)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = ForwardNet(max_cycles=3)

View File

@ -21,8 +21,9 @@ from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import context
from mindspore.common.parameter import Parameter
from tests.security_utils import security_off_wrap
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
@ -125,7 +126,7 @@ def test_while_break_forward():
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_break_backward():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE)
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
forward_net = WhileBreakForwardNet(max_cycles=10)
@ -173,7 +174,7 @@ def test_if_after_if_in_while_break_forward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
# Graph Mode
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
context.set_context(mode=context.GRAPH_MODE)
graph_forward_net = IfAfterIfInWhileBreakForwardNet(max_cycles=10)
graph_mode_out = graph_forward_net(x, y)
assert graph_mode_out == Tensor(np.array(16), mstype.int32)
@ -184,11 +185,12 @@ def test_if_after_if_in_while_break_forward():
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@security_off_wrap
def test_if_after_if_in_while_break_backward():
x = Tensor(np.array(1), mstype.int32)
y = Tensor(np.array(3), mstype.int32)
# Graph Mode
context.set_context(mode=context.GRAPH_MODE)
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
graph_forward_net = IfAfterIfInWhileBreakForwardNet(max_cycles=10)
graph_backward_net = Grad(graph_forward_net)
graph_mode_grads = graph_backward_net(x, y)
@ -409,8 +411,6 @@ def test_for_in_for_break():
# raise a endless loop exception.
@pytest.mark.skip(reason="Infer raise a endless loop exception")
def test_while_true_break():
context.set_context(save_graphs=True)
class WhileTrueBreakNet(nn.Cell):
def __init__(self, t):
super(WhileTrueBreakNet, self).__init__()
@ -441,8 +441,6 @@ def test_while_true_break():
# stuck in vm backend
@pytest.mark.skip(reason="Stuck in vm backend")
def test_continue_stuck_in_vm():
context.set_context(save_graphs=True)
class NetWork(nn.Cell):
def __init__(self, t):
super().__init__()

View File

@ -22,7 +22,7 @@ from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, save_graphs=True, save_graphs_path="graph_paths")
context.set_context(mode=context.GRAPH_MODE)
class ArgumentNum(nn.Cell):

View File

@ -15,9 +15,6 @@
""" test_filter """
from mindspore.nn import Cell
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def is_odd(x):

View File

@ -28,7 +28,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
"""Net definition"""