forked from mindspore-Ecosystem/mindspore
!40687 Add testcase for PyNative tasksink
Merge pull request !40687 from caifubi/master-pynative-tasksink-testcase
This commit is contained in:
commit
23f857c49f
|
@ -20,6 +20,7 @@ from multiprocessing import Process, Queue
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore.common import JitConfig
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms as C
|
||||
|
@ -415,7 +416,12 @@ class GradWrap(Cell):
|
|||
return grad_by_list(self.network, weights)(x, label)
|
||||
|
||||
|
||||
def test_pynative_resnet50():
|
||||
def test_pynative_resnet50(task_sink=False):
|
||||
"""
|
||||
Feature: PyNative ResNet50
|
||||
Description: test PyNative ResNet50
|
||||
Expectation: success
|
||||
"""
|
||||
batch_size = 32
|
||||
num_classes = 10
|
||||
loss_scale = 128
|
||||
|
@ -431,6 +437,8 @@ def test_pynative_resnet50():
|
|||
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
|
||||
if task_sink:
|
||||
net.set_jit_config(JitConfig(jit_level="O2"))
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False)
|
||||
|
||||
|
@ -453,7 +461,7 @@ def test_pynative_resnet50_with_mpi():
|
|||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8)
|
||||
|
||||
good_steps = test_pynative_resnet50()
|
||||
good_steps = test_pynative_resnet50(True)
|
||||
assert good_steps > 10
|
||||
|
||||
|
||||
|
@ -469,7 +477,7 @@ def test_pynative_resnet50_with_env(queue, device_id, device_num):
|
|||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
|
||||
device_num=device_num)
|
||||
|
||||
good_steps = test_pynative_resnet50()
|
||||
good_steps = test_pynative_resnet50(False)
|
||||
queue.put(good_steps)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue