From 51905040dc1ad78b4117e6df21fbb83b3d774601 Mon Sep 17 00:00:00 2001 From: caifubi Date: Mon, 22 Aug 2022 14:51:00 +0800 Subject: [PATCH] Add testcase for PyNative tasksink --- tests/st/pynative/test_pynative_resnet50_ascend.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/st/pynative/test_pynative_resnet50_ascend.py b/tests/st/pynative/test_pynative_resnet50_ascend.py index 39185079308..161a3e28bd5 100644 --- a/tests/st/pynative/test_pynative_resnet50_ascend.py +++ b/tests/st/pynative/test_pynative_resnet50_ascend.py @@ -20,6 +20,7 @@ from multiprocessing import Process, Queue import numpy as np import pytest +from mindspore.common import JitConfig import mindspore.common.dtype as mstype import mindspore.dataset as ds import mindspore.dataset.transforms as C @@ -415,7 +416,12 @@ class GradWrap(Cell): return grad_by_list(self.network, weights)(x, label) -def test_pynative_resnet50(): +def test_pynative_resnet50(task_sink=False): + """ + Feature: PyNative ResNet50 + Description: test PyNative ResNet50 + Expectation: success + """ batch_size = 32 num_classes = 10 loss_scale = 128 @@ -431,6 +437,8 @@ def test_pynative_resnet50(): loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False) + if task_sink: + net.set_jit_config(JitConfig(jit_level="O2")) model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False) @@ -453,7 +461,7 @@ def test_pynative_resnet50_with_mpi(): context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8) - good_steps = test_pynative_resnet50() + good_steps = test_pynative_resnet50(True) assert good_steps > 10 @@ -469,7 +477,7 @@ def test_pynative_resnet50_with_env(queue, device_id, device_num): context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=device_num) - good_steps = test_pynative_resnet50() + good_steps = test_pynative_resnet50(False) queue.put(good_steps)