!40687 Add testcase for PyNative tasksink

Merge pull request !40687 from caifubi/master-pynative-tasksink-testcase
This commit is contained in:
i-robot 2022-08-23 10:42:30 +00:00 committed by Gitee
commit 23f857c49f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 11 additions and 3 deletions

View File

@ -20,6 +20,7 @@ from multiprocessing import Process, Queue
import numpy as np import numpy as np
import pytest import pytest
from mindspore.common import JitConfig
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.transforms as C import mindspore.dataset.transforms as C
@ -415,7 +416,12 @@ class GradWrap(Cell):
return grad_by_list(self.network, weights)(x, label) return grad_by_list(self.network, weights)(x, label)
def test_pynative_resnet50(): def test_pynative_resnet50(task_sink=False):
"""
Feature: PyNative ResNet50
Description: test PyNative ResNet50
Expectation: success
"""
batch_size = 32 batch_size = 32
num_classes = 10 num_classes = 10
loss_scale = 128 loss_scale = 128
@ -431,6 +437,8 @@ def test_pynative_resnet50():
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
if task_sink:
net.set_jit_config(JitConfig(jit_level="O2"))
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'}, model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False) amp_level="O2", keep_batchnorm_fp32=False)
@ -453,7 +461,7 @@ def test_pynative_resnet50_with_mpi():
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8) context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8)
good_steps = test_pynative_resnet50() good_steps = test_pynative_resnet50(True)
assert good_steps > 10 assert good_steps > 10
@ -469,7 +477,7 @@ def test_pynative_resnet50_with_env(queue, device_id, device_num):
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
device_num=device_num) device_num=device_num)
good_steps = test_pynative_resnet50() good_steps = test_pynative_resnet50(False)
queue.put(good_steps) queue.put(good_steps)