diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_thor.py b/tests/st/networks/models/bert/bert_performance/test_bert_thor.py
index 902154a7fea..4fc909c8505 100644
--- a/tests/st/networks/models/bert/bert_performance/test_bert_thor.py
+++ b/tests/st/networks/models/bert/bert_performance/test_bert_thor.py
@@ -217,6 +217,7 @@ def test_bert_thor_8p():
         sum_cost_list.append(0.0)
 
     for _ in range(device_num):
+        assert not q.empty()
         output = q.get()
         loss_list = output['loss']
         cost_list = output['cost']
diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py
index 80bdf9a7ef5..09903fd479b 100644
--- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py
+++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py
@@ -360,6 +360,7 @@ def test_resnet_and_resnet_thor_imagenet_4p():
     acc = 0.0
     cost = 0.0
     for i in range(device_num):
+        assert not q.empty()
         output = q.get()
         acc += output['acc']
         cost += output['cost']
diff --git a/tests/st/pynative/data_parallel/test_pynative_hccl.py b/tests/st/pynative/data_parallel/test_pynative_hccl.py
index 5b8f1241aef..a26426c33c8 100644
--- a/tests/st/pynative/data_parallel/test_pynative_hccl.py
+++ b/tests/st/pynative/data_parallel/test_pynative_hccl.py
@@ -80,6 +80,7 @@ def test_pynative_hccl_8p():
 
     # check result
     for i in range(device_num):
+        assert not q.empty()
         assert q.get()
 
     for i in range(device_num):
@@ -87,7 +88,7 @@ def test_pynative_hccl_8p():
     print("End training...")
 
 
-@pytest.mark.level0
+@pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_single
@@ -110,6 +111,7 @@ def test_pynative_hccl_8pv2():
 
     # check result
     for i in range(device_num):
+        assert not q.empty()
         assert q.get()
 
     for i in range(device_num):
diff --git a/tests/st/pynative/data_parallel/test_pynative_hccl_allreduce.py b/tests/st/pynative/data_parallel/test_pynative_hccl_allreduce.py
index 9a8391a9a35..f706bf9451e 100644
--- a/tests/st/pynative/data_parallel/test_pynative_hccl_allreduce.py
+++ b/tests/st/pynative/data_parallel/test_pynative_hccl_allreduce.py
@@ -90,6 +90,7 @@ def test_pynative_hccl_allreduce_8p():
     # check result
     for i in range(device_num):
         expect_output = [[256, 256, 256, 256], [256, 256, 256, 256], [256, 256, 256, 256]]
+        assert not q.empty()
         output = Tensor(q.get())
         assert np.allclose(output.asnumpy(), expect_output)
 
diff --git a/tests/st/pynative/test_pynative_resnet50_ascend.py b/tests/st/pynative/test_pynative_resnet50_ascend.py
index b7fd0316774..6ecec1ec432 100644
--- a/tests/st/pynative/test_pynative_resnet50_ascend.py
+++ b/tests/st/pynative/test_pynative_resnet50_ascend.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 # ============================================================================
 
+import os
 import time
 import random
+from multiprocessing import Process, Queue
 import numpy as np
 import pytest
 
@@ -32,10 +34,18 @@ from mindspore.nn import Cell
 from mindspore.ops import operations as P
 from mindspore.ops import composite as CP
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.callback import LossMonitor, Callback
+from mindspore.train.callback import Callback
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.model import Model
+from mindspore.context import ParallelMode
+import mindspore.communication.management as D
+MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
+
+np.random.seed(1)
+os.environ['GLOG_v'] = str(2)
+os.environ['ASCEND_GLOBAL_LOG_LEVEL'] = str(3)
+os.environ['ASCEND_GLOBAL_EVENT_ENABLE'] = str(0)
 
 class MyTimeMonitor(Callback):
     def __init__(self, data_size):
@@ -56,7 +66,7 @@ class MyTimeMonitor(Callback):
 
     def step_end(self, run_context):
         step_msseconds = (time.time() - self.step_time) * 1000
-        if step_msseconds < 275:
+        if step_msseconds < 370:
             self.total = self.total + 1
         print(f"step time:{step_msseconds}", flush=True)
 
@@ -405,13 +415,7 @@ class GradWrap(Cell):
         return grad_by_list(self.network, weights)(x, label)
 
 
-@pytest.mark.level1
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.env_single
 def test_pynative_resnet50():
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-
     batch_size = 32
     num_classes = 10
     loss_scale = 128
@@ -423,8 +427,7 @@ def test_pynative_resnet50():
 
     # define callbacks
     time_cb = MyTimeMonitor(data_size=data_set.get_dataset_size())
-    loss_cb = LossMonitor()
-    cb = [time_cb, loss_cb]
+    cb = [time_cb]
 
     loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     loss_scale = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=False)
@@ -435,4 +438,44 @@ def test_pynative_resnet50():
     model.train(1, data_set, callbacks=cb,
                 sink_size=data_set.get_dataset_size(), dataset_sink_mode=True)
 
-    assert time_cb.good_step() > 10
+    return time_cb.good_step()
+
+
+def test_pynative_resnet50_with_env(queue, device_id, device_num):
+    os.system("mkdir " + str(device_id))
+    os.chdir(str(device_id))
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=device_id)
+    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
+    os.environ['RANK_ID'] = str(device_id)
+    os.environ['RANK_SIZE'] = str(device_num)
+    D.init()
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=False,
+                                      device_num=device_num)
+
+    good_steps = test_pynative_resnet50()
+    queue.put(good_steps)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_single
+def test_pynative_resnet50_8p():
+    device_num = 8
+    process = []
+    q = Queue()
+    for i in range(device_num):
+        device_id = i
+        process.append(Process(target=test_pynative_resnet50_with_env, args=(q, device_id, device_num)))
+
+    for i in range(device_num):
+        process[i].start()
+
+    for i in range(device_num):
+        process[i].join()
+
+    # check result
+    for i in range(device_num):
+        assert not q.empty()
+        good_steps = q.get()
+        assert good_steps > 10