!40687 Add testcase for PyNative tasksink

Merge pull request !40687 from caifubi/master-pynative-tasksink-testcase
This commit is contained in:
i-robot 2022-08-23 10:42:30 +00:00 committed by Gitee
commit 23f857c49f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 11 additions and 3 deletions

View File

@ -20,6 +20,7 @@ from multiprocessing import Process, Queue
import numpy as np
import pytest
from mindspore.common import JitConfig
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms as C
@ -415,7 +416,12 @@ class GradWrap(Cell):
return grad_by_list(self.network, weights)(x, label)
def test_pynative_resnet50():
def test_pynative_resnet50(task_sink=False):
"""
Feature: PyNative ResNet50
Description: test PyNative ResNet50
Expectation: success
"""
batch_size = 32
num_classes = 10
loss_scale = 128
@ -431,6 +437,8 @@ def test_pynative_resnet50():
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
if task_sink:
net.set_jit_config(JitConfig(jit_level="O2"))
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
@ -453,7 +461,7 @@ def test_pynative_resnet50_with_mpi():
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8)
good_steps = test_pynative_resnet50()
good_steps = test_pynative_resnet50(True)
assert good_steps > 10
@ -469,7 +477,7 @@ def test_pynative_resnet50_with_env(queue, device_id, device_num):
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
device_num=device_num)
good_steps = test_pynative_resnet50()
good_steps = test_pynative_resnet50(False)
queue.put(good_steps)