forked from mindspore-Ecosystem/mindspore
!40687 Add testcase for PyNative tasksink
Merge pull request !40687 from caifubi/master-pynative-tasksink-testcase
This commit is contained in:
commit
23f857c49f
|
@ -20,6 +20,7 @@ from multiprocessing import Process, Queue
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from mindspore.common import JitConfig
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
import mindspore.dataset as ds
|
import mindspore.dataset as ds
|
||||||
import mindspore.dataset.transforms as C
|
import mindspore.dataset.transforms as C
|
||||||
|
@ -415,7 +416,12 @@ class GradWrap(Cell):
|
||||||
return grad_by_list(self.network, weights)(x, label)
|
return grad_by_list(self.network, weights)(x, label)
|
||||||
|
|
||||||
|
|
||||||
def test_pynative_resnet50():
|
def test_pynative_resnet50(task_sink=False):
|
||||||
|
"""
|
||||||
|
Feature: PyNative ResNet50
|
||||||
|
Description: test PyNative ResNet50
|
||||||
|
Expectation: success
|
||||||
|
"""
|
||||||
batch_size = 32
|
batch_size = 32
|
||||||
num_classes = 10
|
num_classes = 10
|
||||||
loss_scale = 128
|
loss_scale = 128
|
||||||
|
@ -431,6 +437,8 @@ def test_pynative_resnet50():
|
||||||
|
|
||||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
|
loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
|
||||||
|
if task_sink:
|
||||||
|
net.set_jit_config(JitConfig(jit_level="O2"))
|
||||||
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
|
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale, metrics={'acc'},
|
||||||
amp_level="O2", keep_batchnorm_fp32=False)
|
amp_level="O2", keep_batchnorm_fp32=False)
|
||||||
|
|
||||||
|
@ -453,7 +461,7 @@ def test_pynative_resnet50_with_mpi():
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8)
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False, device_num=8)
|
||||||
|
|
||||||
good_steps = test_pynative_resnet50()
|
good_steps = test_pynative_resnet50(True)
|
||||||
assert good_steps > 10
|
assert good_steps > 10
|
||||||
|
|
||||||
|
|
||||||
|
@ -469,7 +477,7 @@ def test_pynative_resnet50_with_env(queue, device_id, device_num):
|
||||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
|
||||||
device_num=device_num)
|
device_num=device_num)
|
||||||
|
|
||||||
good_steps = test_pynative_resnet50()
|
good_steps = test_pynative_resnet50(False)
|
||||||
queue.put(good_steps)
|
queue.put(good_steps)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue