Add testcase for PyNative tasksink

This commit is contained in:
caifubi 2022-08-22 14:51:00 +08:00
parent 2b6d203769
commit 51905040dc
1 changed files with 11 additions and 3 deletions

View File

@ -20,6 +20,7 @@ from multiprocessing import Process, Queue
import numpy as np
import pytest
from mindspore.common import JitConfig
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms as C
@ -415,7 +416,12 @@ class GradWrap(Cell):
return grad_by_list(self.network, weights)(x, label)
def test_pynative_resnet50():
def test_pynative_resnet50(task_sink=False):
"""
Feature: PyNative ResNet50
Description: test PyNative ResNet50
Expectation: success
"""
batch_size = 32
num_classes = 10
loss_scale = 128
@ -431,6 +437,8 @@ def test_pynative_resnet50():
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
if task_sink:
net.set_jit_config(JitConfig(jit_level="O2"))
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
@ -453,7 +461,7 @@ def test_pynative_resnet50_with_mpi():
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8)
good_steps = test_pynative_resnet50()
good_steps = test_pynative_resnet50(True)
assert good_steps > 10
@ -469,7 +477,7 @@ def test_pynative_resnet50_with_env(queue, device_id, device_num):
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
device_num=device_num)
good_steps = test_pynative_resnet50()
good_steps = test_pynative_resnet50(False)
queue.put(good_steps)