!31679 If the GPU memory is set to true, the test case fails.

Merge pull request !31679 from liuchuting/testcase
This commit is contained in:
i-robot 2022-03-23 03:26:25 +00:00 committed by Gitee
commit 3b73182032
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 7 additions and 14 deletions

View File

@ -685,15 +685,8 @@ class Profiler:
logger.warning(err.message)
logger.warning(
'\nMemory Usage is not supported on GPU currently.\n'
'Please running on Ascend if you would like to see memory analysis, '
'otherwise, this warning can be ignored.'
)
logger.warning(
'\nProfile communication is not supported on GPU currently.\n'
'Please running on Ascend if you would like to see cluster communication analysis, '
'otherwise, this warning can be ignored.'
'\nThe training and inference process does not support profiler currently, '
'only individual training or inference is supported.'
)
def _get_step_reduce_op_type(self):

View File

@ -151,7 +151,7 @@ class TestProfiler:
def test_cpu_profiler(self):
if sys.platform != 'linux':
return
self._train_with_profiler(device_target="CPU")
self._train_with_profiler(device_target="CPU", profile_memory=False)
self._check_cpu_profiling_file()
@pytest.mark.level1
@ -159,7 +159,7 @@ class TestProfiler:
@pytest.mark.env_onecard
@security_off_wrap
def test_gpu_profiler(self):
self._train_with_profiler(device_target="GPU")
self._train_with_profiler(device_target="GPU", profile_memory=False)
self._check_gpu_profiling_file()
@pytest.mark.level1
@ -168,12 +168,12 @@ class TestProfiler:
@pytest.mark.env_onecard
@security_off_wrap
def test_ascend_profiler(self):
self._train_with_profiler(device_target="Ascend")
self._train_with_profiler(device_target="Ascend", profile_memory=True)
self._check_d_profiling_file()
def _train_with_profiler(self, device_target):
def _train_with_profiler(self, device_target, profile_memory):
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
profiler = Profiler(profile_memory=True, output_path='data')
profiler = Profiler(profile_memory=profile_memory, output_path='data')
profiler_name = os.listdir(os.path.join(os.getcwd(), 'data'))[0]
self.profiler_path = os.path.join(os.getcwd(), f'data/{profiler_name}/')
ds_train = create_dataset(os.path.join(self.mnist_path, "train"))